diff --git a/.github/ISSUE_TEMPLATE/bug_report.md b/.github/ISSUE_TEMPLATE/bug_report.md index 86322616fa..8165ec95fc 100644 --- a/.github/ISSUE_TEMPLATE/bug_report.md +++ b/.github/ISSUE_TEMPLATE/bug_report.md @@ -20,4 +20,4 @@ A clear and concise description of what you expected to happen. - Environment location: [Bare-metal, Docker, Cloud(specify cloud provider)] **Additional context** -Add any other context about the problem here. \ No newline at end of file +Add any other context about the problem here. diff --git a/.github/ISSUE_TEMPLATE/config.yml b/.github/ISSUE_TEMPLATE/config.yml new file mode 100644 index 0000000000..4572ae1b98 --- /dev/null +++ b/.github/ISSUE_TEMPLATE/config.yml @@ -0,0 +1,5 @@ +blank_issues_enabled: true +contact_links: + - name: CUTLASS Discord + url: https://discord.gg/nvidiadeveloper + about: Come chat about using and contributing to CUTLASS! diff --git a/.github/ISSUE_TEMPLATE/documentation_request.md b/.github/ISSUE_TEMPLATE/documentation_request.md index 9e96105f5d..c9fa21fac9 100644 --- a/.github/ISSUE_TEMPLATE/documentation_request.md +++ b/.github/ISSUE_TEMPLATE/documentation_request.md @@ -32,4 +32,4 @@ A clear and concise description of what documentation you believe it is needed a A clear and concise description of what you want to happen. **Steps taken to search for needed documentation** -List any steps you have taken: \ No newline at end of file +List any steps you have taken: diff --git a/.github/ISSUE_TEMPLATE/submit_question.md b/.github/ISSUE_TEMPLATE/submit_question.md index 743f893fcb..5aa2a672d2 100644 --- a/.github/ISSUE_TEMPLATE/submit_question.md +++ b/.github/ISSUE_TEMPLATE/submit_question.md @@ -7,4 +7,4 @@ assignees: '' --- -**What is your question?** \ No newline at end of file +**What is your question?** diff --git a/.github/workflows/labeler.yml b/.github/workflows/labeler.yml index 6510938e85..23956a02fb 100644 --- a/.github/workflows/labeler.yml +++ b/.github/workflows/labeler.yml @@ -8,4 +8,4 @@ jobs: steps: - uses: actions/labeler@main with: - repo-token: "${{ secrets.GITHUB_TOKEN }}" \ No newline at end of file + repo-token: "${{ secrets.GITHUB_TOKEN }}" diff --git a/.github/workflows/new-issues-to-triage-projects.yml b/.github/workflows/new-issues-to-triage-projects.yml index 3049176e3b..a963cb2f89 100644 --- a/.github/workflows/new-issues-to-triage-projects.yml +++ b/.github/workflows/new-issues-to-triage-projects.yml @@ -32,4 +32,4 @@ jobs: env: GITHUB_TOKEN: ${{ secrets.GITHUB_TOKEN }} GITHUB_PROJECT_URL: https://github.com/NVIDIA/cutlass - GITHUB_PROJECT_COLUMN_NAME: 'Needs prioritizing' \ No newline at end of file + GITHUB_PROJECT_COLUMN_NAME: 'Needs prioritizing' diff --git a/.github/workflows/stale.yml b/.github/workflows/stale.yml index cb2e7275bd..8b65da69aa 100644 --- a/.github/workflows/stale.yml +++ b/.github/workflows/stale.yml @@ -54,4 +54,4 @@ jobs: exempt-pr-labels: "0 - Blocked,0 - Backlog,good first issue" days-before-pr-stale: 90 days-before-pr-close: -1 - operations-per-run: 50 \ No newline at end of file + operations-per-run: 50 diff --git a/.gitignore b/.gitignore index 1328f6b7d6..acddb1f9d1 100644 --- a/.gitignore +++ b/.gitignore @@ -1,2 +1,3 @@ # PyCache files __pycache__/ +cutlass_library.egg-info/ \ No newline at end of file diff --git a/CHANGELOG.md b/CHANGELOG.md index d09b4981e4..1ba870eba2 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,57 +1,270 @@ # NVIDIA CUTLASS Changelog +## [3.6.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.6.0) (2024-10-03) + +- [Hopper structured sparse 
GEMM](./examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu). + + [FP16](./test/unit/gemm/device/sm90_sparse_gemm_f16_f16_f32_tensor_op_f32.cu) + + [FP8](./test/unit/gemm/device/sm90_sparse_gemm_f8_f8_f32_tensor_op_f32.cu) + + [INT8](./test/unit/gemm/device/sm90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu) + + [TF32](./test/unit/gemm/device/sm90_sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu) +- A refactor to the CUTLASS 3.x convolution `kernel::ConvUniversal` [API](./include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp) to bring it in line with `gemm::GemmUniversal`. Now the 3.x convolution API is no longer considered as a beta API. +- Improve [mixed input GEMM](./examples/55_hopper_mixed_dtype_gemm/README.md). + + Added a [lookup table implementation](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode. + + Added [layout pre-shuffling](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu#L50-55) to optimize memory loading. + + Added [interleaved conversion](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu#L50-52) for `{INT4, UINT4, INT8}` x `{FP16, BF16}`. + + Other general optimizations. +- The suffixes of the mixed input kernel schedules have been removed. Use `KernelTmaWarpSpecialized`, `KernelTmaWarpSpecializedPingpong` and `KernelTmaWarpSpecializedCooperative` instead. +- [EVT nodes for Top-K selection and softmax](./include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp) and [GEMM example using those](./examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu). +- [Programmatic Dependent Launch](./include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speedup two back-to-back kernels, and its corresponding [documentations](./media/docs/dependent_kernel_launch.md). +- [A new debugging tool, synclog](./include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](./media/docs/utilities.md#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details. +- A new TMA-enabled [epilogue](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for grouped GEMM that brings significant performance improvement, as well as its EVT support. +- A SIMT-enabled pointer-array [epilogue](./include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp). +- A new [Ping-Pong kernel schedule for Grouped GEMM](./include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp) and some other optimizations. +- [A new instantiation strategy for CUTLASS profiler kernels](./python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](./media/docs/profiler.md#instantiating-more-kernels-with-hopper). +- A new hardware support for comparisons and computations of [`cutlass::bfloat16_t`](./include/cutlass/bfloat16.h) +- Fixed use of isnan on Windows for [`half_t`](./test/unit/core/functional.cu). +- Various improvements and fixes from the community and CUTLASS team. Thanks to everyone who submitted PRs! 
+
+## [3.5.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.1) (2024-07-25)
+
+- [Minimal SM90 WGMMA + TMA GEMM example in 100 lines of code](./examples/cute/tutorial/wgmma_sm90.cu)
+- [Exposure of L2 `cache_hint`s in TMA copy atoms](./include/cute/arch/copy_sm90_tma.hpp#L48)
+- Exposure of raster order and tile swizzle extent in [CUTLASS library profiler](./media/docs/profiler.md#GEMM), and
+[example 48](./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu).
+- [TMA store based and EVT supported epilogues](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for [Hopper pointer array batched kernels](./test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_ptr_array.cu).
+- A new [`GemmSparseUniversal` API for CUTLASS 2.x Ampere kernels](./include/cutlass/gemm/device/gemm_sparse_universal.h) to enable serial and parallel split-k for sparse tensor cores and new tiny tile sizes to better support LLM inferrence: + + [FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f32t_tensor_op_f32_sparse_sm80.cu#L269-L393) and [NT](./test/unit/gemm/device/gemm_f16n_f16t_f32t_tensor_op_f32_sparse_sm80.cu#L269-L411). + + [int8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s32t_tensor_op_s32_sparse_sm80.cu#L264-L452). + + [int4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s32t_tensor_op_s32_sparse_sm80.cu#L264-L452). + + [FP32 TN](./test/unit/gemm/device/gemm_f32t_f32n_f32t_tensor_op_f32_sparse_sm80.cu#L427-L642) and [NT](./test/unit/gemm/device/gemm_f32n_f32t_f32t_tensor_op_f32_sparse_sm80.cu#L427-L456). +- [CUDA host adapter](./include/cutlass/cuda_host_adapter.hpp) extensions to support TMA descriptor construction driver APIs. +- Inclusion of more [Hopper fprop, dgrad, and wgrad convolution kernels in CUTLASS library and profiler](./python/cutlass_library/generator.py). +- Support for residual add (beta != 0) in convolution kernels. +- A new convolution [epilogue](./examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu#L269) for CUTLASS 2.x to support non-packed NHWC output. +- A refactor of [include files throughout CUTLASS core directories](./include/cutlass/gemm/collective/collective_mma_decl.hpp) to reduce circular dependencies and [tests to guard against them](./test/self_contained_includes/CMakeLists.txt). +- [A guide for setting up VSCode to work well with CUTLASS](./media/docs/ide_setup.md) and [expanded code style guide](./media/docs/programming_guidelines.md). +- Better support for MSVC as a host compiler. +- Many performance optimizations, improvements, and bug fixes including fixes for FlashAttention-2. +- Optimal code generation with CUDA toolkit versions 12.4 and 12.5u1. + +## [3.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.5.0) (2024-04-09) + +- Implicit GEMM Convolutions targeting Hopper SM90A via WGMMA + [TMA im2col](./include/cute/atom/copy_traits_sm90_im2col.hpp) + + Native implementation in CUTLASS 3.x using CuTe, mirroring the [same design hierarchy as that of GEMMs](./media/docs/gemm_api_3x.md). + + Support for 1D, 2D, and 3D convolutions in a [rank-agnostic fashion](./include/cutlass/conv/convnd_problem_shape.hpp). + + Support for [Fprop](./test/unit/conv/device_3x/fprop/sm90_conv3d_fprop_implicit_gemm_s8_s8_s32_tensorop_s32.cu), [Dgrad](./test/unit/conv/device_3x/dgrad/sm90_conv2d_dgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu), and [Wgrad](./test/unit/conv/device_3x/wgrad/sm90_conv1d_wgrad_implicit_gemm_f16_f16_f32_tensorop_f16.cu) algorithms + + [CUTLASS profiler support](./python/cutlass_library/conv3x_emitter.py) for 2D and 3D convolutions implemented via the 3.x API. + + NOTE: this is a beta release. Further updates to CUTLASS will include major performance improvements, feature enablement, and possible breaking changes to the API until 3.7 release. Your feedback is welcome on the design! +- Support for [Ada (SM89) FP8 tensor cores via the 2.x API](./examples/58_ada_fp8_gemm/ada_fp8_gemm.cu). Requires CUDA 12.4 or newer. +- [Ampere gather/scatter convolution example](./examples/59_ampere_gather_scatter_conv/README.md) in CuTe and CUTLASS 3.x + + Showcasing how custom kernels can be written and optimized using CUTLASS 3.x and CuTe and the general strategy for implementing convolutions as specializations of GETTs. 
+ + Implementation of a coarse grained sparse gather/scatter kernel achieving peak performance on Ampere class tensor cores. +- 32x and 16x tile sizes are added to CUTLASS 2.x to improve the performance of narrow-tall and wide-short matrices. + + [Ampere FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f32_sm80.cu) and [NT](./test/unit/gemm/device/gemm_f16n_f16t_f16t_tensor_op_f32_sm80.cu#L227-L301), [Ampere INT8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm80.cu#L392-L1342), [Ampere INT4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm80.cu#L372-L934). + + [Turing FP16 TN](./test/unit/gemm/device/gemm_f16t_f16n_f16t_tensor_op_f32_sm75.cu#L55-L394), [Turing INT8 TN](./test/unit/gemm/device/gemm_s8t_s8n_s8t_tensor_op_s32_sm75.cu#L166-L537), [Turing INT4 TN](./test/unit/gemm/device/gemm_s4t_s4n_s4t_tensor_op_s32_sm75.cu#L310-L564). +- Updates to CuTe documentation for [`cute::Tensor<>`](./media/docs/cute/03_tensor.md), [MMA atoms](./media/docs/cute/0t_mma_atom.md), and an overhauled [CuTe GEMM tutorial series](./examples/cute/tutorial). +- Extensions to CuTe to support [L2 prefetching](./include/cute/algorithm/prefetch.hpp) and [TMA store+reductions](./include/cute/arch/copy_sm90_tma.hpp#L1337). +- Remove C++11 requirement on a few CUTLASS 2.x API header files. All CUTLASS files now require C++17. +- Fixes to greatly reduce build warnings. +- Updates and bugfixes from the community (thanks!) + +## [3.4.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.1) (2024-02-14) + +- Statically available [CUTLASS Version macros](./include/cutlass/version.h) that allow for handling API changes between CUTLASS releases on the users' side. +- Improvements for Hopper [Group-GEMMs](./examples/57_hopper_grouped_gemm) and [Pointer-Array Batched GEMMs](./examples/56_hopper_ptr_array_batched_gemm). +- Updates and bugfixes from the community (thanks!). + +## [3.4.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.4.0) (2024-01-12) +* Expanded [Mixed-input Hopper GEMMs](./examples/55_hopper_mixed_dtype_gemm) support covering {16-bit, 8-bit} x {8-bit, 4-bit} input types with fast numerical converters and group scaling factors. +* Performance improvements to [Mixed-input Hopper GEMMs](./examples/55_hopper_mixed_dtype_gemm) +* Beta release of [Pointer-Array Batched GEMMs](./examples/56_hopper_ptr_array_batched_gemm) now available on Hopper GPUs utilizing TMA and WGMMA (requires CUDA 12.3 or above). +* Beta release of [Group-GEMM](./examples/57_hopper_grouped_gemm) utilizing TMA and WGMMA (requires CUDA 12.3 or above). +* [Ampere Sparse GEMM](./examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu) supports Epilogue Visitor Tree (EVT) now. +* NamedBarriers usability improvement and list of [ReservedNamedBarriers](./include/cutlass/arch/barrier.h) has been officially released. +* Improved [CuTe documentation](./media/docs/cute/) including improved clarity and depth of [Quickstart](./media/docs/cute/00_quickstart.md), [CuTe Layout](./media/docs/cute/01_layout.md), and [CuTe Layout Algebra](./media/docs/cute/02_layout_algebra.md). Associated code comments, post-conditions, and details in [CuTe Core Unit Tests](./test/unit/cute/core/) also improved. + +## [3.3](https://github.com/NVIDIA/cutlass/releases/tag/v3.3.0) (2023-10-31) +* [Mixed-input Hopper GEMMs](./examples/55_hopper_mixed_dtype_gemm) support covering 16-bit x 8-bit input operand types. 
+* [Mixed-input Ampere GEMMs](https://github.com/NVIDIA/cutlass/pull/1084) with support for canonical layouts (TN). The implementation supports upcast on operandB {fp16, bf16} x {s8, u8}, and upcast on operandA {s8, u8} x {fp16, bf16}. +* [Copy Async based Hopper GEMMs](./test/unit/gemm/device/sm90_gemm_bf16_bf16_bf16_alignx_tensor_op_f32_warpspecialized_cooperative.cu) - which support lower than 16B aligned input tensors. +* Kernel schedules and Builder support for mixed precision and Copy Async GEMMs with < 16B aligned input tensors. +* Profiler support for lower-aligned Hopper GEMMs. +* Performance Improvements to [Scatter-Gather Hopper Example](./examples/52_hopper_gather_scatter_fusion). +* Sub-Byte type fixes and improvements. +* EVT Support for RELU with Aux bitmap tensor store (used in dRELU). See [SM90 EVT fusions](./include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp) for details. +* Fusion support for backprop fusions including drelu, dgelu, and dbias. +* Support for void-C kernels and SM80 mixed-input GEMMs in the CUTLASS Python interface + +## [3.2.2](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.2) (2023-10-25) +* Minor patch for issue/1138 + +## [3.2.1](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.1) (2023-09-22) +* Python support SM90 Epilogue Visitor Tree (EVT) on top of the C++ support released in 3.2.0. +* SM80 EVT support in C++ and Python. +* Other SM90 epilogue improvements. +* Splitting CUTLASS library into smaller units based on operation, arch and datatypes. See [1105](https://github.com/NVIDIA/cutlass/discussions/1105) for details. +* Making `tools/library/scripts` packageable - `tools/library/scripts` is now moving to `python/cutlass_library`. See the Python [README](./python/README.md) for details. +* SM90 TF32 kernel improvements for all layouts. +* SM90 rasterization direction support in the CUTLASS profiler. +* Improvement for CUTLASS profiler build times. +* Remove Python-C++ bindings. + +## [3.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.2.0) (2023-08-03) + +* New warp-specialized persistent FP8 GEMM kernel [kernel schedules](./include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) and [mainloops](./include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized_fp8.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters. An example showcasing [Hopper warp-specialized FP8 GEMMs](./examples/54_hopper_fp8_warp_specialized_gemm). FP8 GEMMs come with a fast accumulation mode. When enabled, problem execution might be faster but at the cost of lower accuracy because intermediate results will not periodically be promoted to a higher precision. +* New [Epilogue Visitor Tree (EVT)](./examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu) support for Hopper TMA epilogues. EVTs allows for user-defined customized epilogue fusion patterns without having to write a new epilogue. +* [Stream-K](./include/cutlass/gemm/kernel/sm90_tile_scheduler_stream_k.hpp) feature for Hopper. Note that this is only a functional implementation of stream-K, and should not be used for performance comparison. Optimizations are expected in a future release. +* Improved CTA rasterization and support for CTA swizzling for Hopper kernels using the [Tile Scheduler](./include/cutlass/gemm/kernel/sm90_tile_scheduler.hpp). 
+* Improved performance for [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA. +* [Hopper GEMM+Permute](./examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu), an example of fusing tensor reordering (permutation) with GEMM mainloop or epilogue. +* New CUTLASS 2D Convolution Python interface. New [example](./examples/python/03_basic_conv2d.ipynb) here. +* Support for Windows (MSVC) builds. Tested with Visual Studio 2019 v16.11.27 on Windows 10.0. +* Optimal performance using [**CUDA 12.2u1**](https://developer.nvidia.com/cuda-downloads) +* Updates and bugfixes from the community (thanks!) + +## [3.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.1.0) (2023-04-14) +* New CUTLASS Python interface that aims to provide an ease-of-use interface for instantiating, emitting, compiling, and running CUTLASS kernels via Python. More details [here](./python/README.md) and new [examples](./examples/python). +* New [efficient epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative.cu#L783) using TMA for Hopper. +* Support for [fused epilogues](test/unit/gemm/device/sm90_gemm_f16_f16_f16_tensor_op_f32_cluster_warpspecialized_cooperative_bias_elementwise.cu), such Bias, ReLU and GELU, using the new efficient epilogues. +* New [warp-specialized TensorFloat-32 (TF32) GEMM kernels](test/unit/gemm/device/sm90_gemm_tf32_tf32_f32_tensor_op_f32_gmma_rs_cluster_warpspecialized.cu) targeting Hopper TMA. +* New [*warp-specialized persistent cooperative*](./include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_cooperative.hpp) kernel design that allows for larger tile sizes and improves performance on Hopper. +* An [example](./examples/51_hopper_gett) showcasing GEMM-Like Tensor-Tensor Contraction (GETT) capability on Hopper. +* Epilogue builders. Similar to mainloop builders (see [example 49](./examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu)), epilogue builders aim to generate the best-possible epilogue while exposing incremental opt-ins for greater customization. +* Profiler support for overriding kernel and epilogue builder auto schedules for 3.x API kernels, allowing specific policies to be run in the CUTLASS profiler. +* Performance optimizations for the [*warp-specialized persistent ping-pong*](./include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized_pingpong.hpp) kernel. +* Changes to the [GEMM API 3.x](./media/docs/gemm_api_3x.md), involving the host-facing arguments and the underlying `Params` structs. +* [FMHA Backward Pass](./examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu) from Meta xFormers. +* [Streamk GEMM with Broadcast](./examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu) enables epilogue broadcast with StreamK GEMM. +* [Batched B2B GEMM](./examples/13_two_tensor_op_fusion) now can run multiple Back-to-Back GEMM with the same problem size in parallel. +* [Batched Strided GEMV](test/unit/gemm/device/gemv.cu) support both row major and column major input matrix. +* [Permute + GEMM fusion](./examples/39_gemm_permute) can fuse Permute with following GEMM now. Before, we only support fusing GEMM with Permute in the epilogue. +* [Row Broadcast](./include/cutlass/epilogue/threadblock/predicated_tile_iterator_row_broadcast.h) can be fused in the epilogue. +* The GitHub branch is renamed from `master` to `main` in this release. 
+* Optimal performance using [**CUDA 12.1**](https://developer.nvidia.com/cuda-downloads) +* Updates and bugfixes from the community (thanks!) + +## [3.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v3.0.0) (2023-01-23) +* [CuTe](./media/docs/cute/00_quickstart.md), a [new core library and backend](./include/cute) for CUTLASS 3.0 that defines a single Layout vocabulary type and an associated algebra of layouts for a much more expressive and composable abstraction for tensors, sets of parallel agents, and operations by said agents on tensors. +* [A new conceptual operation hierarchy](./media/docs/cutlass_3x_design.md) that replaces the architecture-centric hierarchy of CUTLASS 2.x and [documentation for CUTLASS 3.0's GEMM API changes](./media/docs/gemm_api_3x.md). +* Strict API backwards compatibility that exposes both 2.x and 3.x API kernels through the same [`device::GemmUniversalAdapter`](./include/cutlass/gemm/device/gemm_universal_adapter.h) and [`kernel::GemmUniversal`](./include/cutlass/gemm/kernel/gemm_universal.hpp) types, allowing users to include both APIs in the same translation units. More information can be found in the [3.x backwards compatibility section](./media/docs/cutlass_3x_backwards_compatibility.md). +* Updates to [Functionality](./media/docs/functionality.md) which directs users on which kernels are supported via CUTLASS-2 and CUTLASS-3. +* Updates to [Compatibility](./README.md#compatibility) Section regarding supported compilers, operating systems, CUDA Toolkits, Hardware Architectures and [Target Architecture](./README.md#Target-Architecture). +* New warp-specialized GEMM [kernel schedules](./include/cutlass/gemm/kernel/sm90_gemm_tma_warpspecialized.hpp) and [mainloops](./include/cutlass/gemm/collective/sm90_mma_tma_gmma_ss_warpspecialized.hpp) targeting Hopper architecture that achieve great performance with TMA, WGMMA, and threadblock clusters. +* Extensions to CUTLASS profiler to support threadblock cluster shapes in library and profiler tile configurations. +* [CUTLASS library integration](./tools/library/src/gemm_operation_3x.hpp) for 3.x API kernels built through the new `CollectiveBuilder` API, enabling CUTLASS profiler. +* Support for [Hopper GEMMs](./examples/48_hopper_warp_specialized_gemm) through the new 3.0 API with CuTe-based exposure of the Hopper [Tensor Memory Accelerator](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#data-movement-and-conversion-instructions-cp-async-bulk-tensor) and [WGMMA Tensor Core](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions) features. +* Set of examples that demonstrate the usage of the new 3.0 API to easily build GEMM kernels targeting Hopper: examples [48](./examples/48_hopper_warp_specialized_gemm), [49](./examples/49_hopper_gemm_schedules_with_collective_builder), and [50](./examples/50_hopper_gemm_with_epilogue_swizzle). + +## [2.11.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.11.0) (2022-11-19) +* [Stream-K](./examples/47_ampere_gemm_universal_streamk), which is a new general way to do split-K. It can not only improve performance, but can also significantly reduce the number of tile sizes that need to be profiled to find the best one. +* [Fused multi-head attention Kernel](./examples/41_fused_multi_head_attention). It has two variants: one uses batched GEMM for the fixed sequence length, and the other one uses group GEMM for the variable sequence length. Both versions just need one kernel. 
+* [Dual GEMM](./examples/45_dual_gemm), which can fuse A x B and A x C into one kernel. The two GEMMs have no producer-consumer dependency.
+* Hopper improves [double precision matrix multiplication](./test/unit/gemm/device/gemm_f64n_f64t_f64t_tensor_op_f64_sm90.cu) by 2x compared to Ampere at iso-clocks. It is supported since CUDA 11.8.
+* [BLAS3](./test/unit/gemm/device/hemm_cf64_cf64_cf64_tensor_op_f64_sm90.cu) functions with Hopper's new double precision matrix multiplication instructions.
+* [ELL Block Sparse GEMM](./examples/43_ell_block_sparse_gemm), which uses an [ELL matrix](https://developer.nvidia.com/blog/accelerating-matrix-multiplication-with-block-sparse-format-and-nvidia-tensor-cores/) to describe the sparsity of the A matrix. B and output matrices are still dense. The block size can be arbitrary.
+* Optimized [Group Conv](./examples/42_ampere_tensorop_group_conv) for SingleGroup mode, which requires that the output channel per group is a multiple of Threadblock tile N.
+* [Optimized DepthWise Conv](./examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu). Two new modes are added:
+  * [kOptimized](./test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - uses direct convolution to compute instead of implicit GEMM.
+    * The restrictions are: 1) input channel, output channel, and group number should be multiples of (128 / sizeof(input element)). 2) The input filter size should be the same as the template parameter configuration.
+  * [kFixedStrideDilation](./test/unit/conv/device/depthwise_conv2d_fprop_direct_conv_fixed_stride_dilation_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) - puts stride and dilation into templates to further improve performance. In this mode, the kernel keeps some inputs persistent in registers to squeeze out more performance, so large filter/stride/dilation values are not recommended.
+    * The restrictions are: 1) input channel, output channel, and group number should be multiples of (128 / sizeof(input element)). 2) Input filter size, stride, and dilation should be the same as the template parameter configuration.
+* [Scripts](./examples/44_multi_gemm_ir_and_codegen) to fuse multiple back-to-back GEMMs. The implementation was discussed in a GTC'22 Spring [talk](https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41606/).
+* [FP8 data type definition](./include/cutlass/float8.h) and [conversion routines](./include/cutlass/numeric_conversion.h#L1274-2115).
+* Updates and bugfixes from the community (thanks!). Big shout out to Meta's [xFormers](https://github.com/facebookresearch/xformers).
+
+* **Deprecation announcement:** CUTLASS plans to deprecate the following:
+  * Maxwell and Pascal GPU architectures
+  * Ubuntu 16.04
+  * CUDA 10.2
+
+## [2.10.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.10.0) (2022-08-23)
+* [CUTLASS Python](./examples/40_cutlass_py) now supports GEMM, CONV, and Group GEMM for different data types as well as different epilogue flavours.
+* Optimizations for CUTLASS's [Grouped GEMM](./examples/24_gemm_grouped/gemm_grouped.cu) kernel. The threadblock scheduling part is improved. Some computation can be moved to the host side if applicable. [Grouped Syr2k](./examples/38_syr2k_grouped/syr2k_grouped.cu) kernels are added, too.
+* Optimizations for [GEMM+Softmax](./examples/35_gemm_softmax). All the reduction computation is fused into the previous GEMM. More template arguments are provided to fine-tune the performance.
+* [Grouped GEMM for Multihead Attention](./examples/41_multi_head_attention).
This general group-GEMM-based MHA does not require the sequence lengths of all GEMMs to be the same, which makes it most useful for natural language processing.
+* [GEMM + Layer norm fusion for Ampere](./examples/37_gemm_layernorm_gemm_fusion/) splits the layernorm into two parts, and both of them can be fused into the preceding and following GEMMs separately. In addition to using the square sum to compute the variance of the layernorm, [Shift-K](https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data) is provided if the square sum raises numerical issues.
+* [GEMM Epilogue Permutation Fusion](./examples/39_gemm_permute) can apply a user-provided permutation layout mapping in the GEMM epilogue.
+* [Grouped convolution targeting implicit GEMM](test/unit/conv/device/group_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) introduces the first group convolution implementation to CUTLASS. It is an Analytical implementation, not an Optimized one. The restrictions are: 1) input and output channel numbers should be multiples of the group number. 2) split-K is not supported. The implementation has 2 modes:
+  * kSingleGroup: output channel per group is a multiple of Threadblock tile N.
+  * kMultipleGroup: Threadblock tile N is a multiple of output channel per group.
+* [Depthwise separable convolution](test/unit/conv/device/depthwise_conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_simt_f16_sm60.cu) introduces the first depthwise convolution, which is also Analytical for now. The restrictions are: 1) SIMT only, 2) no split-K, and 3) input channel, output channel, and group number are all equal.
+* Standalone [Layernorm](./tools/util/include/cutlass/util/device_layernorm.h) and [Pooling](./tools/util/include/cutlass/util/device_nhwc_pooling.h) kernels.
+* [Back-to-back GEMM/CONV](./examples/13_two_tensor_op_fusion) relaxes the requirement that the first GEMM K dimension needs to be a multiple of the Threadblock Tile K dimension.
+* Optimal performance using [**CUDA 11.6u2**](https://developer.nvidia.com/cuda-downloads)
+* Updates and bugfixes from the community (thanks!)
## [2.9.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.9.0) (2022-04-21) -* [First layer Convolution kernels](/test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) specialized for small channel counts and reduced alignment - * [Few channels](/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h) specialization for reduced alignment capabilities - * [Fixed channels](/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h) further specialized when channel count perfectly matches the access vector size - * [Unit tests](/test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) - * [Python-based instance emitter](/tools/library/scripts/generator.py) in the CUTLASS Library and support in the Profiler +* [First layer Convolution kernels](./test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) specialized for small channel counts and reduced alignment + * [Few channels](./include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h) specialization for reduced alignment capabilities + * [Fixed channels](./include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h) further specialized when channel count perfectly matches the access vector size + * [Unit tests](./test/unit/conv/device/conv2d_fprop_few_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) + * [Python-based instance emitter](./python/cutlass_library/generator.py) in the CUTLASS Library and support in the Profiler * [BLAS3](https://docs.nvidia.com/cuda/cublas/index.html#cublas-level-3-function-reference) operators accelerated by Tensor Cores * Supported types: f32, cf32, f64, cf64, tf32x3, complex tf32x3 - * [HERK](/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu) with [emitter](/tools/library/scripts/rank_k_operation.py) - * [SYRK](/test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu) with [emitter](/tools/library/scripts/rank_k_operation.py) - * [SYMM](/test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu) with [emitter](/tools/library/scripts/symm_operation.py) - * [TRMM](/test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu) with [emitter](/tools/library/scripts/trmm_operation.py) - * [Unit tests](/test/unit/gemm/device/testbed_rank_k_universal.h) -* [CUTLASS Python](/examples/40_cutlass_py) demonstrating JIT compilation of CUTLASS kernels and a Python-based runtime using [CUDA Python](https://developer.nvidia.com/cuda-python) - * [Python-based runtime](/tools/library/scripts/rt.py) interoperable with existing emitters -* [GEMM + Softmax example](/examples/35_gemm_softmax) -* [Gather and Scatter Fusion with GEMM](/examples/36_gather_scatter_fusion) can gather inputs and scatters outputs based on indices vectors in the same GEMM kernel. 
+ * [HERK](./test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu) with [emitter](./python/cutlass_library/rank_k_operation.py) + * [SYRK](./test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu) with [emitter](./python/cutlass_library/rank_k_operation.py) + * [SYMM](./test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu) with [emitter](./python/cutlass_library/symm_operation.py) + * [TRMM](./test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu) with [emitter](./python/cutlass_library/trmm_operation.py) + * [Unit tests](./test/unit/gemm/device/testbed_rank_k_universal.h) +* [CUTLASS Python](./examples/40_cutlass_py) demonstrating JIT compilation of CUTLASS kernels and a Python-based runtime using [CUDA Python](https://developer.nvidia.com/cuda-python) + * [Python-based runtime](./tools/library/scripts/rt.py) interoperable with existing emitters +* [GEMM + Softmax example](./examples/35_gemm_softmax) +* [Gather and Scatter Fusion with GEMM](./examples/36_gather_scatter_fusion) can gather inputs and scatters outputs based on indices vectors in the same GEMM kernel. * It can select random rows in a row major matrix. * It can select random columns in a column major matrix. -* [Back-to-back GEMM/CONV](examples/13_two_tensor_op_fusion) fully supports buffering the first GEMM/CONV results in the shared memory for the latter one to use. It can eliminate register spill when the tile size is big. Additionally, bias vector add is supported in the first GEMM/CONV. +* [Back-to-back GEMM/CONV](./examples/13_two_tensor_op_fusion) fully supports buffering the first GEMM/CONV results in the shared memory for the latter one to use. It can eliminate register spill when the tile size is big. Additionally, bias vector add is supported in the first GEMM/CONV. * Supported kernels: GEMM and CONV. * Supported types: fp16 and int8. * Supported architectures: Turing and Ampere. -* [Transposed Convolution](/examples/34_transposed_conv2d) (a.k.a Deconvolution) support which reuses Dgrad implementation. -* [Utility functions](/tools/util/include/cutlass/util) that can pad NHWC and convert between NCHW and NHWC. +* [Transposed Convolution](./examples/34_transposed_conv2d) (a.k.a Deconvolution) support which reuses Dgrad implementation. +* [Utility functions](./tools/util/include/cutlass/util) that can pad NHWC and convert between NCHW and NHWC. * [Small alignment implicit gemm](https://github.com/NVIDIA/cutlass/issues/242) support for Fprop/Dgrad/Wgrad so that padding is no longer mandated to use tensor cores in these kernels. * Epilogue enhancement: * Eliminate bank conflicts in int8 tensor core kernels. * Half2 usage if epilogue compute type is fp16. * More activation functions: Silu, Hardswish, Leaky Relu. - * New elementwise fusion pattern for [residual block](/include/cutlass/epilogue/thread/linear_combination_residual_block.h). -* [Group GEMM](/examples/24_gemm_grouped) thread block number calculation fix which helps to launch the intended number of threadblocks to fully occupy the GPUs. + * New elementwise fusion pattern for [residual block](./include/cutlass/epilogue/thread/linear_combination_residual_block.h). +* [Group GEMM](./examples/24_gemm_grouped) thread block number calculation fix which helps to launch the intended number of threadblocks to fully occupy the GPUs. * [Parallel GEMM splitk](https://github.com/NVIDIA/cutlass/pull/277) support in the CUTLASS profiler. 
-* Optimal performance using [**CUDA 11.7**](https://developer.nvidia.com/cuda-downloads)
+* Optimal performance using [**CUDA 11.6u2**](https://developer.nvidia.com/cuda-downloads)
 * Updates and bugfixes from the community (thanks!)
+
 ## [2.8.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.8.0) (2021-11-19)
 * **TF32x3:** emulated single-precision using Tensor Cores
   * 45+ TFLOPs on NVIDIA A100
-  * [GEMM SDK example](/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu) (real)
-  * [COMPLEX GEMM SDK example](/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm.cu) (complex)
-  * [Implicit GEMM Convolution SDK example](/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu)
+  * [GEMM SDK example](./examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu) (real)
+  * [COMPLEX GEMM SDK example](./examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm.cu) (complex)
+  * [Implicit GEMM Convolution SDK example](./examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu)
 * **Mainloop fusion for Convolution:** convolution with fused per-channel scale-bias-relu
-  * [Conv Fprop SDK example](/examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu)
-  * [Conv WGrad SDK example](/examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu)
-  * [cutlass::conv::device::ImplicitGemmConvolutionFusion](/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h)
+  * [Conv Fprop SDK example](./examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu)
+  * [Conv WGrad SDK example](./examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu)
+  * [cutlass::conv::device::ImplicitGemmConvolutionFusion](./include/cutlass/conv/device/implicit_gemm_convolution_fusion.h)
 * **Grouped GEMM:** similar to batched GEMM with distinct problem size per group
-  * [SDK example](/examples/24_gemm_grouped) with performance comparison with Batched Strided GEMM
-  * [cutlass::gemm::device::GemmGrouped](/include/cutlass/gemm/device/gemm_grouped.h)
-* [Implicit GEMM Convolution fusion](/examples/13_two_tensor_op_fusion/) supports staging 1st convolution's output accumulator in the shared memory on Turing. This allows more flexible warp tile sizes and less regsiter pressue.
+  * [SDK example](./examples/24_gemm_grouped) with performance comparison with Batched Strided GEMM
+  * [cutlass::gemm::device::GemmGrouped](./include/cutlass/gemm/device/gemm_grouped.h)
+* [Implicit GEMM Convolution fusion](./examples/13_two_tensor_op_fusion/) supports staging the 1st convolution's output accumulator in shared memory on Turing. This allows more flexible warp tile sizes and less register pressure.
 * Optimal performance using [**CUDA 11.5**](https://developer.nvidia.com/cuda-downloads)
 * Updates from the community (thanks!)
@@ -61,11 +274,11 @@ * CUDA 10.2 ## [2.7.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.7.0) (2021-09-24) - * Mainloop fusion for GEMM: [summation over A or B](/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu) - * [Strided DGRAD (optimized iterators)](/include/cutlass/conv/kernel/default_conv2d_dgrad.h) - * [Half-precision GELU_taylor activation functions](/include/cutlass/epilogue/thread/activation.h#L196) + * Mainloop fusion for GEMM: [summation over A or B](./examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu) + * [Strided DGRAD (optimized iterators)](./include/cutlass/conv/kernel/default_conv2d_dgrad.h) + * [Half-precision GELU_taylor activation functions](./include/cutlass/epilogue/thread/activation.h#L196) * Use these when accumulation and epilogue compute types are all `cutlass::half_t` - * Tuning and bug fixes to [fused GEMM + GEMM example](/examples/13_two_tensor_op_fusion/) + * Tuning and bug fixes to [fused GEMM + GEMM example](./examples/13_two_tensor_op_fusion/) * Support for smaller than 128b aligned Convolutions: [see examples](test/unit/conv/device/conv2d_fprop_implicit_gemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f16_sm80.cu#L272) * Caching of results to accelerate Convolution [unit tests](test/unit/conv/device/cache_testbed_output.h) * Can be enabled or disabled by running `cmake .. -DCUTLASS_TEST_ENABLE_CACHED_RESULTS=OFF` @@ -80,27 +293,27 @@ ## [2.6.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.6.0) (2021-07-22) * Optimal performance when compiled with the [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit) - * Adopt the new L2 prefetch feature in [cp.async](/include/cutlass/arch/memory.h) and [global load](/include/cutlass/arch/memory_sm80.h) + * Adopt the new L2 prefetch feature in [cp.async](./include/cutlass/arch/memory.h) and [global load](./include/cutlass/arch/memory_sm80.h) * Fused operators with GEMM and Convolution * [Fused broadcast in epilogue](test/unit/gemm/device/gemm_with_broadcast_f16n_f16n_f16n_tensorop_f32_sm75.cu) - * [Fused partial reduction in epilogue](/test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu) + * [Fused partial reduction in epilogue](./test/unit/gemm/device/gemm_with_reduction_f16n_f16n_f16n_tensorop_f32_sm75.cu) * 64b tensor strides and leading dimensions support for GEMMs - * Affine rank=2 matrix layouts - * Row stride and column stride for matrices using [cutlass::layout::AffineRank2](/include/cutlass/layout/matrix.h) - * Support [FP64 tensor core](/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu) and SIMT GEMM. - * [Batched GEMV](/test/unit/gemm/device/gemv.cu) preview implementation + * Affine rank=2 matrix layouts + * Row stride and column stride for matrices using [cutlass::layout::AffineRank2](./include/cutlass/layout/matrix.h) + * Support [FP64 tensor core](./examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu) and SIMT GEMM. 
+ * [Batched GEMV](./test/unit/gemm/device/gemv.cu) preview implementation * [New strided Dgrad](test/unit/conv/device/conv2d_strided_dgrad_implicit_gemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_sm80.cu) implementation * Accelerates over previous implementation by cutting down redundant math by 4x * Support using new `Dy` and `w` analytic iterators and existing `cutlass::conv::device::ImplicitGemmConvolution` interface * Quaternion-valued GEMM and Convolution in single- and double-precision (targeting CUDA Cores) - * Updates to [quaternion.h](/include/cutlass/quaternion.h) and [functional.h](/include/cutlass/functional.h) - * SDK Example for [GEMM](/examples/21_quaternion_gemm/quaternion_gemm.cu) and [Convolution](/examples/22_quaternion_gemm/quaternion_conv.cu) - * [Unit tests for GEMM](/test/unit/gemm/device/simt_qgemm_nn_sm50.cu) and [Convolution](/test/unit/conv/device/conv2d_fprop_implicit_gemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32_sm50.cu) + * Updates to [quaternion.h](./include/cutlass/quaternion.h) and [functional.h](./include/cutlass/functional.h) + * SDK Example for [GEMM](./examples/21_quaternion_gemm/quaternion_gemm.cu) and [Convolution](./examples/22_quaternion_conv/quaternion_conv.cu) + * [Unit tests for GEMM](./test/unit/gemm/device/simt_qgemm_nn_sm50.cu) and [Convolution](./test/unit/conv/device/conv2d_fprop_implicit_gemm_qf32nhwc_qf32nhwc_qf32nhwc_simt_f32_sm50.cu) * Many improvements to the epilogue. - * Provide an [option](/include/cutlass/epilogue/threadblock/epilogue.h) to not fully unroll the epilogue to reduce the code size and improve the performance when using complicated elementwise operations + * Provide an [option](./include/cutlass/epilogue/threadblock/epilogue.h) to not fully unroll the epilogue to reduce the code size and improve the performance when using complicated elementwise operations * Performance improvement for FP16 tensor core kernels * Bug fixes - * Enhanced Clang support and the combination of Clang 13 and CUDA 11.4 can build and run kernels from Pascal and Ampere. + * Enhanced Clang support and the combination of Clang 13 and CUDA 11.4 can build and run kernels from Pascal and Ampere. 
* Updated minimum CUDA Toolkit requirement to 10.2 * [CUDA 11.4 Toolkit](https://developer.nvidia.com/cuda-toolkit) recommended * Corrections and bug fixes reported by the CUTLASS community @@ -109,17 +322,17 @@ ## [2.5.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.5.0) (2021-02-26) * Tensor reductions * _m_-to-_n_ reductions of tensors with affine layout - * [Specializations](/test/unit/reduction/device/tensor_reduce_contiguous.cu) for reductions including contiguous dimension - * [Specializations](/test/unit/reduction/device/tensor_reduce_strided.cu) for reductions excluding contiguous dimension + * [Specializations](./test/unit/reduction/device/tensor_reduce_contiguous.cu) for reductions including contiguous dimension + * [Specializations](./test/unit/reduction/device/tensor_reduce_strided.cu) for reductions excluding contiguous dimension * Custom reduction functors such as `cutlass::logical_and` * Large tensor support, up to 2^63 elements (however, each dimension is limited to an extent of 2^31) * Optimizations for 3-D convolution - * [Optimized tile iterators](include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h) using precomputed delta table for 3-D convolution + * [Optimized tile iterators](./include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h) using precomputed delta table for 3-D convolution * Full coverage of [forward](test/unit/conv/device/conv3d_fprop_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu) and [backwards](test/unit/conv/device/conv3d_dgrad_implicit_gemm_f16ndhwc_f16ndhwc_f32ndhwc_tensor_op_f32_sm80.cu) passes for 3D convolution - * [Fused Convolution+Convolution example](/examples/13_two_tensor_op_fusion/README.md) + * [Fused Convolution+Convolution example](./examples/13_two_tensor_op_fusion/README.md) * Corrections and bug fixes reported by the CUTLASS community * Thank you for filing these issues! 
- + ## [2.4.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.4.0) (2020-11-19) * Implicit GEMM convolution kernels supporting CUDA and Tensor Cores on NVIDIA GPUs @@ -127,11 +340,11 @@ * Data type: FP32, complex, Tensor Float 32 (TF32), BFloat16 (BF16), Float16, Int4, Int8, Int32 * Spatial dimensions: 1-D, 2-D, and 3-D * Layout: NHWC, NCxHWx - * Implicit GEMM convolution components: + * Implicit GEMM convolution components: * Global memory iterators supporting Fprop, Dgrad, and Wgrad * `MmaMultistage` for implicit GEMM convolution for NVIDIA Ampere architecture * `MmaPipeline` for implicit GEMM convolution for NVIDIA Volta and Turing architectures - * [Documentation](/media/docs/implicit_gemm_convolution.md) describing Implicit GEMM Convolution algorithm and implementation + * [Documentation](./media/docs/implicit_gemm_convolution.md) describing Implicit GEMM Convolution algorithm and implementation ## [2.3.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.3.0) (2020-09-23) * [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/) @@ -139,21 +352,21 @@ * Direct access to Sparse Tensor Cores and maximum performance via [`mma.sp.sync`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma-and-friends) * Fast SGEMM targeting GeForce RTX 30-series CUDA Cores * Minor Features: - * [Activation functions](/include/cutlass/epilogue/thread/activation.h) such as [GeLU](/include/cutlass/epilogue/thread/linear_combination_gelu.h) and [Sigmoid](/include/cutlass/epilogue/thread/linear_combination_sigmoid.h) - * Small [matrix](/include/cutlass/matrix.h) and [quaternion](/include/cutlass/quaternion.h) template classes in device code - * [Floating-point constants](/include/cutlass/constants.h) + * [Activation functions](./include/cutlass/epilogue/thread/activation.h) such as [GeLU](./include/cutlass/epilogue/thread/linear_combination_gelu.h) and [Sigmoid](./include/cutlass/epilogue/thread/linear_combination_sigmoid.h) + * Small [matrix](./include/cutlass/matrix.h) and [quaternion](./include/cutlass/quaternion.h) template classes in device code + * [Floating-point constants](./include/cutlass/constants.h) * NVIDIA Ampere GPU Architecture examples and documentation: - * [Tensor Float 32](/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu) and - * [Sparse Tensor Cores](/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu) - * Documentation added on CUTLASS [efficient row-major epilogue](/media/docs/gemm_api.md#efficient-epilogue) + * [Tensor Float 32](./examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu) and + * [Sparse Tensor Cores](./examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu) + * Documentation added on CUTLASS [efficient row-major epilogue](./media/docs/gemm_api.md#efficient-epilogue) ## [2.2.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.2.0) (2020-06-08) * [NVIDIA Ampere Architecture features](https://devblogs.nvidia.com/nvidia-ampere-architecture-in-depth/) - * Fast Tensor Core operations: + * Fast Tensor Core operations: * Maximum performance via [`mma.sync`](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma-and-friends) * Tensor Float 32, BFloat16, and double-precision data types * Mixed integer data types (int8, int4, bin1) - * Asynchronous copy for deep software pipelines via [`cp.async`](https://docs.nvidia.com/cuda/parallel-thread-execution) + * Asynchronous copy 
for deep software pipelines via [`cp.async`](https://docs.nvidia.com/cuda/parallel-thread-execution) * Described in [GTC 2020 Webinar (SR 21745)](https://developer.nvidia.com/gtc/2020/video/s21745) (free registration required) * Features: * SDK examples showing GEMM fused with bias+relu and fused GEMM+GEMM @@ -165,11 +378,11 @@ * Disabled F16C by default for compatibility - enable on cmake command line with `-DCUTLASS_ENABLE_F16C=ON` ## [2.1.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.1.0) (2020-04-06) - * BLAS-style host-side API added to [CUTLASS Library](/media/docs/quickstart.md#cutlass-library) + * BLAS-style host-side API added to [CUTLASS Library](./media/docs/quickstart.md#cutlass-library) * API to launch compiled kernel instances for GEMM and planar complex GEMM * Planar Complex GEMM kernels targeting Volta and Turing Tensor Cores * Computes complex matrix products on matrices stored as disjoint real and imaginary parts - * [SDK Examples of Planar Complex GEMMs](/examples/10_planar_complex/planar_complex.cu) + * [SDK Examples of Planar Complex GEMMs](./examples/10_planar_complex/planar_complex.cu) * Minor enhancements and bug fixes ## [2.0.0](https://github.com/NVIDIA/cutlass/releases/tag/v2.0.0) (2019-11-19) @@ -179,10 +392,10 @@ * Encapsulated functionality embodying modern C++11 programming techniques * Optimized containers and data types for efficient, generic, portable device code * Updates to: - * [Quick start guide](/media/docs/quickstart.md) - * [Documentation](/README.md#documentation) - * [Utilities](/media/docs/utilities.md) - * [CUTLASS Profiler](/media/docs/profiler.md) + * [Quick start guide](./media/docs/quickstart.md) + * [Documentation](./README.md#documentation) + * [Utilities](./media/docs/utilities.md) + * [CUTLASS Profiler](./media/docs/profiler.md) * Native Turing Tensor Cores * Efficient GEMM kernels targeting Turing Tensor Cores * Mixed-precision floating point, 8-bit integer, 4-bit integer, and binarized operands @@ -246,7 +459,7 @@ ## Copyright -Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. SPDX-License-Identifier: BSD-3-Clause ``` diff --git a/CITATION.cff b/CITATION.cff index ea053e66fb..ea97f1f68e 100644 --- a/CITATION.cff +++ b/CITATION.cff @@ -5,33 +5,61 @@ message: >- following metadata. 
type: software authors: - - given-names: Andrew - email: akerr@nvidia.com - family-names: Kerr + - given-names: Vijay + family-names: Thakkar + email: vithakkar@nvidia.com + affiliation: NVIDIA + - given-names: Pradeep + family-names: Ramani + email: prramani@nvidia.com + affiliation: NVIDIA + - given-names: Cris + family-names: Cecka + email: ccecka@nvidia.com + affiliation: NVIDIA + - given-names: Aniket + family-names: Shivam + email: ashivam@nvidia.com + affiliation: NVIDIA + - given-names: Honghao + family-names: Lu + email: honghaol@nvidia.com + affiliation: NVIDIA + - given-names: Ethan + family-names: Yan + email: etyan@nvidia.com + affiliation: NVIDIA + - given-names: Jack + family-names: Kosaian + email: jkosaian@nvidia.com + affiliation: NVIDIA + - given-names: Mark + family-names: Hoemmen + email: mhoemmen@nvidia.com affiliation: NVIDIA - given-names: Haicheng family-names: Wu - affiliation: NVIDIA email: haichengw@nvidia.com - - given-names: Manish - family-names: Gupta - affiliation: Google - email: manigupta@google.com - - given-names: Dustyn - family-names: Blasig - email: dblasig@nvidia.com affiliation: NVIDIA - - given-names: Pradeep - family-names: Ramini - email: prramani@nvidia.com + - given-names: Andrew + family-names: Kerr + email: akerr@nvidia.com + affiliation: NVIDIA + - given-names: Matt + family-names: Nicely + email: mnicely@nvidia.com affiliation: NVIDIA - given-names: Duane family-names: Merrill email: dumerrill@nvidia.com affiliation: NVIDIA - - given-names: Aniket - family-names: Shivam - email: ashivam@nvidia.com + - given-names: Dustyn + family-names: Blasig + email: dblasig@nvidia.com + affiliation: NVIDIA + - given-names: Fengqi + family-names: Qiao + email: fqiao@nvidia.com affiliation: NVIDIA - given-names: Piotr family-names: Majcher @@ -49,10 +77,12 @@ authors: family-names: Wang email: jinw@nvidia.com affiliation: NVIDIA - - given-names: Matt - family-names: Nicely - email: mnicely@nvidia.com - affiliation: NVIDIA + - given-names: Manish + family-names: Gupta + affiliation: Google + email: manigupta@google.com + + repository-code: 'https://github.com/NVIDIA/cutlass' abstract: >- CUTLASS is a collection of CUDA C++ template @@ -71,12 +101,12 @@ abstract: >- flexibility simplifies their use as building blocks within custom kernels and applications. keywords: - - 'cutlass, tensor cores, cuda' + - 'cutlass, tensor cores, cuda, cute, nvidia, gpu, linear algebra, matrix computations' license: BSD-3-Clause -license-url: https://github.com/NVIDIA/cutlass/blob/v2.9.0/LICENSE.txt -version: '2.9' -date-released: '2022-04-27' +license-url: https://github.com/NVIDIA/cutlass/blob/v3.0.0/LICENSE.txt +version: '3.0.0' +date-released: '2023-01-23' identifiers: - type: url - value: "https://github.com/NVIDIA/cutlass/tree/v2.9.0" - description: The GitHub release URL of tag 2.9.0 + value: "https://github.com/NVIDIA/cutlass/tree/v3.0.0" + description: The GitHub release URL of tag 3.0.0 diff --git a/CMakeLists.txt b/CMakeLists.txt index cfed600b72..e9c501bc2b 100755 --- a/CMakeLists.txt +++ b/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -26,7 +26,8 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
-cmake_minimum_required(VERSION 3.12.4 FATAL_ERROR) +cmake_minimum_required(VERSION 3.19 FATAL_ERROR) +cmake_policy(SET CMP0112 NEW) if(cutlass_LOADED) # If CUTLASS has been previously fetched and loaded, don't do it again. @@ -37,31 +38,69 @@ else() endif() message(STATUS "CMake Version: ${CMAKE_VERSION}") +set(IMPLICIT_CMAKE_CXX_STANDARD OFF CACHE BOOL "Do not explicitly specify -std=c++17 if set") + +# To reduce duplicate version locations, parse the version out of the +# main versions.h file and reuse it here. + +file(READ ${CMAKE_CURRENT_SOURCE_DIR}/include/cutlass/version.h VERSION_FILE_CONTENTS) +string(REGEX MATCH "#define CUTLASS_MAJOR ([0-9]+)" _CUTLASS_VERSION_MAJOR "${VERSION_FILE_CONTENTS}") +set(_CUTLASS_VERSION_MAJOR ${CMAKE_MATCH_1}) +string(REGEX MATCH "#define CUTLASS_MINOR ([0-9]+)" _CUTLASS_VERSION_MINOR "${VERSION_FILE_CONTENTS}") +set(_CUTLASS_VERSION_MINOR ${CMAKE_MATCH_1}) +string(REGEX MATCH "#define CUTLASS_PATCH ([0-9]+)" _CUTLASS_VERSION_PATCH "${VERSION_FILE_CONTENTS}") +set(_CUTLASS_VERSION_PATCH ${CMAKE_MATCH_1}) + +message(STATUS "CUTLASS ${_CUTLASS_VERSION_MAJOR}.${_CUTLASS_VERSION_MINOR}.${_CUTLASS_VERSION_PATCH}") + +## CUTLASS PROJECT ############################################################# + +project(CUTLASS VERSION ${_CUTLASS_VERSION_MAJOR}.${_CUTLASS_VERSION_MINOR}.${_CUTLASS_VERSION_PATCH} LANGUAGES CXX) + +################################################################################ + +if (CMAKE_CXX_COMPILER_ID MATCHES "GNU") + set(CUTLASS_GNU_HOST_COMPILE ON CACHE BOOL "Using GNU tools for host code compilation") +endif() +if (CMAKE_CXX_COMPILER_ID MATCHES "[Cc]lang") + set(CUTLASS_CLANG_HOST_COMPILE ON CACHE BOOL "Using Clang tools for host code compilation") +endif() +if (CMAKE_CXX_COMPILER_ID MATCHES "MSVC") + set(CUTLASS_MSVC_HOST_COMPILE ON CACHE BOOL "Using MSVC tools for host code compilation") +endif() + +################################################################################ -project(CUTLASS VERSION 2.9.0 LANGUAGES CXX) include(${CMAKE_CURRENT_SOURCE_DIR}/CUDA.cmake) -if (CUDA_VERSION VERSION_LESS 10.2) - message(WARNING "CUTLASS ${CUTLASS_VERSION} requires CUDA 10.2 or higher, and strongly recommends CUDA 11.0 or higher.") -elseif (CUDA_VERSION VERSION_LESS 11.0) - message(WARNING "CUTLASS ${CUTLASS_VERSION} support for CUDA ${CUDA_VERSION} is deprecated, please use CUDA 11.0 or higher.") +if (CUDA_VERSION VERSION_LESS 11.3) + message(WARNING "CUTLASS ${CUTLASS_VERSION} requires CUDA 11.4 or higher, and strongly recommends CUDA 11.8 or higher.") +elseif (CUDA_VERSION VERSION_LESS 11.4) + message(WARNING "CUTLASS ${CUTLASS_VERSION} support for CUDA ${CUDA_VERSION} is deprecated, please use CUDA 11.8 or higher.") +endif() + +if(CUTLASS_GNU_HOST_COMPILE AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.3) + message(FATAL_ERROR "GCC version must be at least 7.3!") endif() +if (CUTLASS_CLANG_DEVICE_COMPILE AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0) + message(FATAL_ERROR "Clang 7.0+ required for GPU compilation") +endif() find_package(Doxygen QUIET) +################################################################################ + # -# CUTLASS 2.x requires C++11 +# CUTLASS 3.x requires C++17 # -set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD 17) set(CMAKE_CXX_STANDARD_REQUIRED ON) set(CMAKE_CXX_EXTENSIONS OFF) -if(CUTLASS_NATIVE_CUDA) - set(CMAKE_CUDA_STANDARD 11) - set(CMAKE_CUDA_STANDARD_REQUIRED ON) -else() - list(APPEND CUTLASS_CUDA_NVCC_FLAGS --std=c++11) -endif() +set(CMAKE_CUDA_STANDARD 17) 
+set(CMAKE_CUDA_STANDARD_REQUIRED ON) + +list(APPEND CUTLASS_CUDA_NVCC_FLAGS --expt-relaxed-constexpr) if(CMAKE_INSTALL_PREFIX_INITIALIZED_TO_DEFAULT) set(CMAKE_INSTALL_PREFIX install CACHE PATH "Default installation location." FORCE) @@ -69,16 +108,28 @@ endif() message(STATUS "Default Install Location: ${CMAKE_INSTALL_PREFIX}") +set(CUTLASS_TEST_LEVEL "0" CACHE STRING "Level of tests to compile.") +# 0 - Sanity, 1 - Release-Quality, 2 - Exhaustive + +find_package(Python3 3.5 COMPONENTS Interpreter REQUIRED) + +################################################################################ set(CUTLASS_ENABLE_HEADERS_ONLY OFF CACHE BOOL "Enable only the header library") if(CUTLASS_ENABLE_HEADERS_ONLY) set(CUTLASS_ENABLE_EXAMPLES_INIT OFF) set(CUTLASS_ENABLE_TOOLS_INIT ON) set(CUTLASS_ENABLE_LIBRARY_INIT OFF) + set(CUTLASS_ENABLE_TESTS_INIT OFF) else() set(CUTLASS_ENABLE_EXAMPLES_INIT ON) set(CUTLASS_ENABLE_TOOLS_INIT ON) set(CUTLASS_ENABLE_LIBRARY_INIT ON) + if(${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME}) + set(CUTLASS_ENABLE_TESTS_INIT ON) + else() + set(CUTLASS_ENABLE_TESTS_INIT OFF) + endif() endif() set(CUTLASS_TEST_UNIT_ENABLE_WARNINGS OFF CACHE BOOL "Enable warnings on waived unit tests.") @@ -87,44 +138,46 @@ set(CUTLASS_ENABLE_EXAMPLES ${CUTLASS_ENABLE_EXAMPLES_INIT} CACHE BOOL "Enable C set(CUTLASS_ENABLE_TOOLS ${CUTLASS_ENABLE_TOOLS_INIT} CACHE BOOL "Enable CUTLASS Tools") set(CUTLASS_ENABLE_LIBRARY ${CUTLASS_ENABLE_LIBRARY_INIT} CACHE BOOL "Enable CUTLASS Library") set(CUTLASS_ENABLE_PROFILER ${CUTLASS_ENABLE_LIBRARY} CACHE BOOL "Enable CUTLASS Profiler") - -if(${CMAKE_PROJECT_NAME} STREQUAL ${PROJECT_NAME}) - set(CUTLASS_ENABLE_TESTS_INIT ${CUTLASS_ENABLE_LIBRARY}}) -else() - set(CUTLASS_ENABLE_TESTS_INIT OFF) -endif() +set(CUTLASS_ENABLE_PERFORMANCE ${CUTLASS_ENABLE_PROFILER} CACHE BOOL "Enable CUTLASS Performance") set(CUTLASS_ENABLE_TESTS ${CUTLASS_ENABLE_TESTS_INIT} CACHE BOOL "Enable CUTLASS Tests") - -if (CUTLASS_ENABLE_TESTS) - include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/googletest.cmake) +set(CUTLASS_ENABLE_GTEST_UNIT_TESTS ${CUTLASS_ENABLE_TESTS} CACHE BOOL "Enable CUTLASS GTest-based Unit Tests") +set(CUTLASS_USE_SYSTEM_GOOGLETEST OFF CACHE BOOL "Use system/external installation of GTest") +set(CUTLASS_USE_PACKED_TUPLE ON CACHE BOOL "If ON, make cute::tuple be new standard-layout tuple type; if OFF, use the original cute::tuple implementation that is _not_ standard-layout.") +if (CUTLASS_USE_PACKED_TUPLE) + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTE_USE_PACKED_TUPLE=1) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DCUTLASS_USE_PACKED_TUPLE=1") + message(STATUS "Make cute::tuple be the new standard-layout tuple type") +elseif() + message(STATUS "Use the original cute::tuple implementation that is _not_ standard-layout") endif() +################################################################################ + set(CUTLASS_NVCC_ARCHS_SUPPORTED "") -if (NOT CUDA_VERSION VERSION_LESS 7.5) - list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 53) -endif() -if (NOT CUDA_VERSION VERSION_LESS 8.0) - list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 60 61) -endif() -if (NOT CUDA_VERSION VERSION_LESS 9.0) - list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 70) +if (CUDA_VERSION VERSION_GREATER_EQUAL 11.4) + list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 70 72 75 80 86 87) endif() -if (NOT CUDA_VERSION VERSION_LESS 9.2) - list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 72) +if (CUDA_VERSION VERSION_GREATER_EQUAL 11.8) + list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 89 90) endif() -if (NOT CUDA_VERSION VERSION_LESS 10.0) - 
list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 75) -endif() -if (NOT CUDA_VERSION VERSION_LESS 11.0) - list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 80) -endif() -if (NOT CUDA_VERSION VERSION_LESS 11.1 AND NOT CUDA_COMPILER MATCHES "[Cc]lang") - list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 86) +if (CUDA_VERSION VERSION_GREATER_EQUAL 12.0) + list(APPEND CUTLASS_NVCC_ARCHS_SUPPORTED 90a) endif() set(CUTLASS_NVCC_ARCHS ${CUTLASS_NVCC_ARCHS_SUPPORTED} CACHE STRING "The SM architectures requested.") set(CUTLASS_NVCC_ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS} CACHE STRING "The SM architectures to build code for.") +# Find unsupported and deprecated compute capabilities +if (CUTLASS_NVCC_ARCHS_SUPPORTED) + set(CUTLASS_NVCC_ARCHS_UNSUPPORTED ${CUTLASS_NVCC_ARCHS}) + list(REMOVE_ITEM CUTLASS_NVCC_ARCHS_UNSUPPORTED ${CUTLASS_NVCC_ARCHS_SUPPORTED}) + if (CUTLASS_NVCC_ARCHS_UNSUPPORTED) + message(WARNING "Using unsupported or deprecated compute capabilities ${CUTLASS_NVCC_ARCHS_UNSUPPORTED}. Support may be removed in future versions.") + endif() +else() + message(WARNING "No supported compute capabilities for CUDA ${CUDA_VERSION}.") +endif() + # Special policy introduced in CMake 3.13 if (POLICY CMP0076) cmake_policy(SET CMP0076 NEW) @@ -133,6 +186,7 @@ endif() include(GNUInstallDirs) link_directories(${CUDA_TOOLKIT_ROOT_DIR}/lib64/stubs) +link_directories(${CUDA_TOOLKIT_ROOT_DIR}/lib64) ################################################################################################### # @@ -161,9 +215,12 @@ if(WIN32) set(gtest_force_shared_crt ON CACHE BOOL "Use shared (DLL) run-time lib even when Google Test is built as static lib" FORCE) endif() +set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DCUTLASS_VERSIONS_GENERATED") +set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DCUTLASS_VERSIONS_GENERATED") + if (WIN32) - # Enable more warnings and treat as errors - list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/W3 -Xcompiler=/WX) + # Enable more warnings. Add "-Xcompiler=/WX" to enable warnings as errors. + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/W3) # Disable warning on Unicode characters list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/wd4819) @@ -177,7 +234,7 @@ if (${CUTLASS_NVCC_VERBOSE}) endif() # -# CUTLASS NAMESPACE +# CUTLASS NAMESPACE # set(CUTLASS_NAMESPACE "cutlass" CACHE STRING "Top level namespace of CUTLASS") @@ -186,15 +243,44 @@ set(CUTLASS_NVCC_EMBED_PTX ON CACHE BOOL "Embed compiled PTX into executables.") set(CUTLASS_NVCC_KEEP OFF CACHE BOOL "Keep intermediate files generated by NVCC.") set(CUTLASS_ENABLE_F16C OFF CACHE BOOL "Enable F16C x86 extensions in host code.") +################################################################################ # # CUTLASS generator cmake configuration # -set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma delimited list of operation name filters. Default '' means all operations are enabled.") -set(CUTLASS_LIBRARY_KERNELS "" CACHE STRING "Comma delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. 
If 'all' is specified, all kernels are enabled.") -set(CUTLASS_LIBRARY_IGNORE_KERNELS "" CACHE STRING "Comma delimited list of kernel names to exclude from build.") -# Test Levels L0, L1, L2 -set(CUTLASS_TEST_LEVEL "0" CACHE STRING "Level of tests to compile.") +# Kernel unified filter file + +set(KERNEL_FILTER_FILE "" CACHE STRING "KERNEL FILTER FILE FULL PATH") + +if (KERNEL_FILTER_FILE AND NOT CUTLASS_LIBRARY_KERNELS) + # If a kernel filter file is specified, we want to generate and then + # filter on the entire kernel set, not the default kernel + # (sub)set. The user may have overridden CUTLASS_LIBRARY_KERNELS, in which + # case the resulting kernel set will be the intersection of the two + # options differenced against CUTLASS_LIBRARY_IGNORE_KERNELS. + set(CUTLASS_LIBRARY_KERNELS_INIT "*") +else() + set(CUTLASS_LIBRARY_KERNELS_INIT "") +endif() + +if (KERNEL_FILTER_FILE) + get_filename_component(KERNEL_FILTER_FILE "${KERNEL_FILTER_FILE}" ABSOLUTE) + set(KERNEL_FILTER_FILE "${KERNEL_FILTER_FILE}" CACHE STRING "KERNEL FILTER FILE FULL PATH" FORCE) +endif() + +set(SELECTED_KERNEL_LIST "selected" CACHE STRING "Name of the filtered kernel list") + +if(KERNEL_FILTER_FILE) + message(STATUS "Full path of filter file: ${KERNEL_FILTER_FILE}") +endif() + +set(CUTLASS_LIBRARY_OPERATIONS "all" CACHE STRING "Comma-delimited list of operation name filters. Default '' means all operations are enabled.") +set(CUTLASS_LIBRARY_KERNELS ${CUTLASS_LIBRARY_KERNELS_INIT} CACHE STRING "Comma-delimited list of kernel name filters. If unspecified, only the largest tile size is enabled. If the string 'all' is specified, all kernels are enabled.") +set(CUTLASS_LIBRARY_IGNORE_KERNELS "" CACHE STRING "Comma-delimited list of kernels to exclude from build. This option ONLY takes effect if CUTLASS_LIBRARY_KERNELS is set.") +set(CUTLASS_LIBRARY_EXCLUDE_KERNELS "" CACHE STRING "Comma-delimited list of kernels to exclude from build. This option always takes effect, whether or not CUTLASS_LIBRARY_KERNELS is set. It also can exclude kernels from the filter file (see KERNEL_FILTER_FILE).") +set(CUTLASS_LIBRARY_INSTANTIATION_LEVEL "" CACHE STRING "Instantiation level for SM90 kernels. Set to `max` and make sure CUTLASS_LIBRARY_KERNELS is non-empty to stamp all possible kernel configurations.") + +################################################################################ set(CUTLASS_TEST_ENABLE_CACHED_RESULTS ON CACHE BOOL "Enable caching and reuse of test results in unit tests") @@ -214,6 +300,8 @@ if (CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED) list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED=1) endif() +################################################################################ + # # CUDA 10.1 introduces "mma" in PTX performing collective matrix multiply operations. 
# @@ -231,6 +319,15 @@ list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_DEBUG_TRACE_LEVEL=${CUTLASS_DEBUG_ set(CUTLASS_ENABLE_TENSOR_CORE_MMA ${CUTLASS_ENABLE_TENSOR_CORE_MMA_DEFAULT} CACHE BOOL "Enable PTX mma instruction for collective matrix multiply operations.") +set(CUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES OFF CACHE BOOL + "Enable an extended set of SM90 WGMMA instruction shapes (may lead to increased compilation times)") +if(CUTLASS_ENABLE_SM90_EXTENDED_MMA_SHAPES) + message(STATUS "Enabled extended SM90 WGMMA instruction shapes") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +endif() + +set(CUTLASS_SKIP_REDUCTION_INIT OFF CACHE BOOL "Disable init reduction workspace") + # # NOTE: running with asan and CUDA requires the following environment variable: # @@ -258,10 +355,53 @@ if(CUTLASS_NVCC_EMBED_PTX) list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-include-ptx=all) endif() +if (CUTLASS_SKIP_REDUCTION_INIT) + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_SKIP_REDUCTION_INIT=1) +endif() + if (CUTLASS_ENABLE_TENSOR_CORE_MMA) list(APPEND CUTLASS_CUDA_FLAGS -DCUTLASS_ENABLE_TENSOR_CORE_MMA=1) endif() +set(CUTLASS_PROFILER_DISABLE_REFERENCE OFF CACHE BOOL "Disable compilation of reference kernels in the CUTLASS profiler.") +if (CUTLASS_PROFILER_DISABLE_REFERENCE) + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_PROFILER_DISABLE_REFERENCE=1) +endif() + +if (CUTLASS_ENABLE_GDC_FOR_SM90) + message(STATUS "Grid Dependency Control (GDC) is enabled for SM90 kernels (required for programmatic dependent launches).") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -DCUTLASS_ENABLE_GDC_FOR_SM90=1) +endif() + +set(CUTLASS_ENABLE_SYNCLOG OFF CACHE BOOL "Enable synchronization event logging for race condition debugging. WARNING: This redefines __syncthreads() and __syncwarp() in all downstream code!") + +if (CUTLASS_ENABLE_SYNCLOG) + set(CMAKE_CUDA_SEPARABLE_COMPILATION ON) + string(APPEND CMAKE_CXX_FLAGS " -DCUTLASS_ENABLE_SYNCLOG=1") + string(APPEND CMAKE_CUDA_FLAGS " -DCUTLASS_ENABLE_SYNCLOG=1") +endif() + + + +# Warnings-as-error exceptions and warning suppressions for Clang builds +if (CUTLASS_CLANG_HOST_COMPILE) + + set(FLAGS_TO_ADD + "-Wno-error=implicit-int-conversion" + "-Wno-error=pass-failed" + "-Wno-error=inconsistent-missing-override" + "-Wno-sign-conversion" + "-Wno-unused-parameter" + ) + + foreach(FLAG ${FLAGS_TO_ADD}) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} ${FLAG}") + list(APPEND CUTLASS_CUDA_NVCC_FLAGS "${FLAG}") + list(APPEND CUTLASS_CUDA_CLANG_FLAGS "${FLAG}") + endforeach() + +endif() + if (NOT MSVC AND CUTLASS_NVCC_KEEP) # MSVC flow handles caching already, but for other generators we handle it here. 
set(CUTLASS_NVCC_KEEP_DIR ${CMAKE_CURRENT_BINARY_DIR}/tmp CACHE PATH "Location to store NVCC scratch files") @@ -272,15 +412,26 @@ endif() if (CUTLASS_ENABLE_F16C AND NOT CMAKE_CROSSCOMPILING) list(APPEND CUTLASS_CUDA_FLAGS -DCUTLASS_ENABLE_F16C=1) - if ((CMAKE_CXX_COMPILER_ID MATCHES "GNU") OR (CMAKE_CXX_COMPILER_ID MATCHES "Clang")) + if (CUTLASS_GNU_HOST_COMPILE OR CUTLASS_CLANG_HOST_COMPILE) list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-mf16c) - elseif((CMAKE_CXX_COMPILER_ID MATCHES "MSVC")) + elseif(CUTLASS_MSVC_HOST_COMPILE) list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=/arch:AVX2) endif() endif() -list(APPEND CUTLASS_CUDA_NVCC_FLAGS $<$:-Xcompiler=-Wconversion>) -list(APPEND CUTLASS_CUDA_NVCC_FLAGS $<$:-Xcompiler=-fno-strict-aliasing>) +if (CUTLASS_ENABLE_OPENMP_TESTS) + find_package(OpenMP) + if(OpenMP_CXX_FOUND) + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=${OpenMP_CXX_FLAGS}) + else() + message(WARNING "CUTLASS_ENABLE_OPENMP_TESTS set but OpenMP not found.") + endif() +endif() + +if(UNIX) + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-Wconversion) + list(APPEND CUTLASS_CUDA_NVCC_FLAGS -Xcompiler=-fno-strict-aliasing) +endif() # Don't leak lineinfo in release builds if (NOT CMAKE_BUILD_TYPE MATCHES "Release") @@ -288,28 +439,13 @@ if (NOT CMAKE_BUILD_TYPE MATCHES "Release") list(APPEND CUTLASS_CUDA_NVCC_FLAGS -lineinfo) endif() -#Report CUDA build flags -if (CUDA_COMPILER MATCHES "[Cc]lang") - if(CUTLASS_CUDA_CLANG_FLAGS) - message(STATUS "Using CLANG flags: ${CUTLASS_CUDA_CLANG_FLAGS}") - endif() -else() - if(CUTLASS_CUDA_NVCC_FLAGS) - message(STATUS "Using NVCC flags: ${CUTLASS_CUDA_NVCC_FLAGS}") - endif() -endif() - -if(CUDA_COMPILER MATCHES "[Cc]lang") - if( NOT CMAKE_CXX_COMPILER_ID MATCHES "Clang" ) +if (CUTLASS_CLANG_DEVICE_COMPILE) + if (NOT CUTLASS_CLANG_HOST_COMPILE) message(FATAL_ERROR "Clang CUDA compilation requires Clang CXX compilation. Currently CMAKE_CXX_COMPILER is ${CMAKE_CXX_COMPILER_ID}" ) endif() - if (CMAKE_CXX_COMPILER_VERSION VERSION_LESS 7.0) - message(FATAL_ERROR "Clang 7.0+ required for GPU compilation") - endif() - - # There are numerous Clang versions that can work with each CUDA toolkit and the - # the checks are not very useful so we are turning them off and using testing to + # There are numerous Clang versions that can work with each CUDA toolkit and the + # the checks are not very useful so we are turning them off and using testing to # ensure the various combinations work properly. list(APPEND CUTLASS_CUDA_CLANG_FLAGS --cuda-path=${CUDA_TOOLKIT_ROOT_DIR}) @@ -320,71 +456,102 @@ if(CUDA_COMPILER MATCHES "[Cc]lang") list(APPEND CUTLASS_CUDA_CLANG_FLAGS -mllvm -unroll-threshold=5000) list(APPEND CUTLASS_CUDA_CLANG_FLAGS -Wno-unused-command-line-argument) - string(REPLACE "." ";" CUDA_VERSION_PARTS ${CMAKE_CUDA_COMPILER_VERSION}) - list(GET CUDA_VERSION_PARTS 0 CUDA_VERSION_MAJOR) - list(GET CUDA_VERSION_PARTS 1 CUDA_VERSION_MINOR) list(APPEND CUTLASS_CUDA_CLANG_FLAGS -D__CUDACC_VER_MAJOR__=${CUDA_VERSION_MAJOR} -D__CUDACC_VER_MINOR__=${CUDA_VERSION_MINOR}) - # needed for libcublasLt.so in case it's installed in the same location as libcudart.so # dynamic linker can find it if linker sets RPATH (forced by --disable-new-tags) # Otherwise linker uses RUNPATH and that does not propagate to loaded libs. 
list(APPEND CUTLASS_CUDA_CLANG_FLAGS -Wl,--disable-new-dtags) link_libraries(nvidia::cudart) + link_libraries(nvidia::cuda_driver) + +endif() + +#Report CUDA build flags +if (CUTLASS_CLANG_DEVICE_COMPILE AND CUTLASS_CUDA_CLANG_FLAGS) + set(__FLAG_GROUP Clang) + set(__FLAG_LIST CUTLASS_CUDA_CLANG_FLAGS) +else(CUTLASS_NVCC_DEVICE_COMPILE AND CUTLASS_CUDA_NVCC_FLAGS) + set(__FLAG_GROUP NVCC) + set(__FLAG_LIST CUTLASS_CUDA_NVCC_FLAGS) +endif() + +set(__FLAG_DISPLAY_STRING "") +set(__FLAG_DISPLAY_SEPARATOR) +list(JOIN ${__FLAG_LIST} "\n " __FLAG_DISPLAY_STRING) +message(STATUS "Using the following ${__FLAG_GROUP} flags: \n ${__FLAG_DISPLAY_STRING}") + +# Known gcc 8.1-8.3 SFINAE issue (fixed in gcc 8.4), check https://gcc.gnu.org/bugzilla/show_bug.cgi?id=87748 +# Also see https://github.com/NVIDIA/nccl/issues/835 for nvtx3.hpp +if (CUTLASS_GNU_HOST_COMPILE AND CMAKE_CXX_COMPILER_VERSION VERSION_GREATER_EQUAL 8.1 AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS_EQUAL 8.3) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -DNVTX3_USE_CHECKED_OVERLOADS_FOR_GET=0") + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -DNVTX3_USE_CHECKED_OVERLOADS_FOR_GET=0") endif() -# Support for 128-bit integers if using NVIDIA C++ compiler +# Support for 128-bit integers if using NVIDIA C++ compiler if (${CMAKE_CXX_COMPILER_ID} MATCHES "PGI" OR ${CMAKE_CXX_COMPILER_ID} MATCHES "NVHPC") set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} -Mint128 ") endif() -if (CMAKE_VERSION VERSION_GREATER_EQUAL 3.18) - # CMake 3.18 added support for CUDA_ARCHITECTURES target property. We will use this - # property for CMake 3.18+, so we request the NEW behavior for correct compatibility. - # https://cmake.org/cmake/help/v3.18/policy/CMP0104.html#policy:CMP0104 - cmake_policy(SET CMP0104 NEW) +# CMake 3.18 added support for CUDA_ARCHITECTURES target property. We will use this +# property for CMake 3.18+, so we request the NEW behavior for correct compatibility. +# https://cmake.org/cmake/help/v3.18/policy/CMP0104.html#policy:CMP0104 +cmake_policy(SET CMP0104 NEW) + +if (MSVC) + + # MSVC by default does not apply the correct __cplusplus version as specified by the C++ standard + # because MSVC is not a completely compliant implementation. This option forces MSVC to use the + # appropriate value given the requested --std option. This fixes a compilation issue mismatch + # between GCC/Clang and MSVC. + # + # error : a constexpr function cannot have a nonliteral return type "dim3" + # + # See https://developercommunity.visualstudio.com/t/msvc-incorrectly-defines-cplusplus/139261 + + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /Zc:__cplusplus") + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /Zc:__cplusplus") + +endif() + +# Some tests require this build option in order to link. 
+if (MSVC) + set(CMAKE_CXX_FLAGS "${CMAKE_CXX_FLAGS} /bigobj") + set(CMAKE_CUDA_FLAGS "${CMAKE_CUDA_FLAGS} -Xcompiler /bigobj") endif() function(cutlass_apply_cuda_gencode_flags TARGET) + set(options) + set(oneValueArgs) + set(multiValueArgs SM_ARCHS) + cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) + + if (__SM_ARCHS) + set(ARCHS_ENABLED ${__SM_ARCHS}) + else() + set(ARCHS_ENABLED ${CUTLASS_NVCC_ARCHS_ENABLED}) + endif() - set(NVCC_FLAGS) - set(CLANG_FLAGS) set(__CMAKE_CUDA_ARCHS) - foreach(ARCH ${CUTLASS_NVCC_ARCHS_ENABLED}) - list(APPEND CLANG_FLAGS --cuda-gpu-arch=sm_${ARCH}) + foreach(ARCH ${ARCHS_ENABLED}) set(CODES) if(CUTLASS_NVCC_EMBED_CUBIN) - list(APPEND CODES sm_${ARCH}) list(APPEND __CMAKE_CUDA_ARCHS ${ARCH}-real) endif() - if(CUTLASS_NVCC_EMBED_PTX) - list(APPEND CODES compute_${ARCH}) + if(CUTLASS_NVCC_EMBED_PTX AND NOT CUTLASS_CLANG_DEVICE_COMPILE) + # If we're using clang for device compilation, the ptx is inserted + # via another command line option and the `-virtual` flags will cause an error. list(APPEND __CMAKE_CUDA_ARCHS ${ARCH}-virtual) endif() list(JOIN CODES "," CODES_STR) - list(APPEND NVCC_FLAGS -gencode=arch=compute_${ARCH},code=[${CODES_STR}]) endforeach() - if (CUDA_COMPILER MATCHES "[Cc]lang") - target_compile_options( - ${TARGET} - PRIVATE - $<$:${CLANG_FLAGS}> - ) - elseif(CMAKE_VERSION GREATER_EQUAL 3.18) - set_property(TARGET ${TARGET} PROPERTY CUDA_ARCHITECTURES ${__CMAKE_CUDA_ARCHS}) - else() - target_compile_options( - ${TARGET} - PRIVATE - $<$:${NVCC_FLAGS}> - ) - endif() + set_property(TARGET ${TARGET} PROPERTY CUDA_ARCHITECTURES ${__CMAKE_CUDA_ARCHS}) endfunction() -# Cache the flags so they are available when the function below is called anywhere globally. +# Cache the flags so they are available when the function below is called anywhere globally. set(__CUTLASS_CUDA_FLAGS ${CUTLASS_CUDA_FLAGS} CACHE INTERNAL "") set(__CUTLASS_CUDA_FLAGS_RELEASE ${CUTLASS_CUDA_FLAGS_RELEASE} CACHE INTERNAL "") @@ -401,8 +568,8 @@ set(__CUTLASS_CUDA_NVCC_FLAGS_DEBUG ${CUTLASS_CUDA_NVCC_FLAGS_DEBUG} CACHE INTER function(cutlass_apply_standard_compile_options TARGET) - if(CUDA_COMPILER MATCHES "[Cc]lang") - set(CUDA_COMPILE_LANGUAGE CXX) + if(CUTLASS_CLANG_DEVICE_COMPILE) + set(CUDA_COMPILE_LANGUAGE CUDA) set(_FLAGS ${__CUTLASS_CUDA_FLAGS} ${__CUTLASS_CUDA_CLANG_FLAGS}) set(_FLAGS_RELEASE ${__CUTLASS_CUDA_FLAGS_RELEASE} ${__CUTLASS_CUDA_CLANG_FLAGS_RELEASE}) set(_FLAGS_RELWITHDEBINFO ${__CUTLASS_CUDA_FLAGS_RELWITHDEBINFO} ${__CUTLASS_CUDA_CLANG_FLAGS_RELWITHDEBINFO}) @@ -434,7 +601,8 @@ endfunction() # GLOB for CUTLASS header files. Should we use a static list instead? 
file(GLOB_RECURSE CUTLASS_INCLUDE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR} include/cutlass/*.h) -file(GLOB_RECURSE CUTLASS_CUTLASS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/include include/cutlass/*.h) +file(GLOB_RECURSE CUTLASS_CUTLASS RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/include include/cutlass/*.h include/cutlass/*.hpp include/cutlass/*.inl) +file(GLOB_RECURSE CUTLASS_CUTE RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/include include/cute/*.h*) file(GLOB_RECURSE CUTLASS_NVRTC RELATIVE ${CMAKE_CURRENT_SOURCE_DIR}/test test/unit/nvrtc/kernel/*.h) ################################################################################################### @@ -459,7 +627,10 @@ set(CUTLASS_TOOLS_UTIL_INCLUDE_DIR ${CMAKE_CURRENT_SOURCE_DIR}/tools/util/includ include_directories(${CUTLASS_INCLUDE_DIR}) target_compile_features(CUTLASS INTERFACE cxx_std_11) -target_compile_definitions(CUTLASS INTERFACE CUTLASS_NAMESPACE=${CUTLASS_NAMESPACE}) + +if (NOT CUTLASS_NAMESPACE STREQUAL "cutlass") + target_compile_definitions(CUTLASS INTERFACE CUTLASS_NAMESPACE=${CUTLASS_NAMESPACE}) +endif() if (NOT DEFINED CUTLASS_REVISION) @@ -481,8 +652,8 @@ if (NOT DEFINED CUTLASS_REVISION) endif() configure_file( - ${CMAKE_CURRENT_SOURCE_DIR}/cmake/version.h.in - ${CMAKE_CURRENT_BINARY_DIR}/include/cutlass/version.h + ${CMAKE_CURRENT_SOURCE_DIR}/cmake/version_extended.h.in + ${CMAKE_CURRENT_BINARY_DIR}/include/cutlass/version_extended.h @ONLY) target_include_directories( @@ -491,6 +662,12 @@ target_include_directories( $ $ $ + ) + +# Mark CTK headers as system to supress warnings from them +target_include_directories( + CUTLASS + SYSTEM INTERFACE $ ) @@ -543,6 +720,7 @@ if(NOT WIN32) "-Wl,-rpath,'$ORIGIN/../lib'" "-Wl,-rpath,'${CUDA_TOOLKIT_ROOT_DIR}/lib64'" "-Wl,-rpath,'${CUDA_TOOLKIT_ROOT_DIR}/lib'" + ${CMAKE_DL_LIBS} ) endif() @@ -550,6 +728,15 @@ endif() include(CTest) enable_testing() + +if (CUTLASS_ENABLE_GTEST_UNIT_TESTS) + if (CUTLASS_USE_SYSTEM_GOOGLETEST) + find_package(GTest REQUIRED) + else() + include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/googletest.cmake) + endif() +endif() + if (NOT TARGET test_all) add_custom_target(test_all) endif() @@ -569,6 +756,9 @@ install(DIRECTORY DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/ctest) ################################################################################ +set(CUTLASS_ENABLE_CUBLAS OFF CACHE BOOL "cuBLAS usage for tests") +set(CUTLASS_ENABLE_CUDNN OFF CACHE BOOL "cuDNN usage for tests") + include(${CMAKE_CURRENT_SOURCE_DIR}/cuBLAS.cmake) if (CUTLASS_ENABLE_CUBLAS) @@ -583,35 +773,54 @@ endif() ################################################################################ -set(CUTLASS_CTEST_TEMPLATE_FILE ${CMAKE_CURRENT_LIST_DIR}/cmake/CTestTestfile.config.cmake) +set(CUTLASS_DEFAULT_ACTIVE_TEST_SETS "default" CACHE STRING "Default + activated test sets. In `make test` mode, this string determines the + active set of tests. In `ctest` mode, this value can be overriden + with CUTLASS_TEST_SETS environment variable when running the ctest + executable.") + +file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/${CMAKE_INSTALL_BINDIR}") +set(CUTLASS_CTEST_TEMPLATE_FILE ${CMAKE_CURRENT_LIST_DIR}/cmake/CTestTestfile.configure.cmake) set(CUTLASS_CTEST_GENERATED_FILES "" CACHE INTERNAL "") function(cutlass_add_executable_tests NAME TARGET) -# -# Generates test rules for `make test`, `make test_all`, and `ctest` invoked from either the +# +# Generates test rules for `make test`, `make test_all`, and `ctest` invoked from either the # or the / after installation. -# +# # NAME: The base name for the test. 
Can be run with `make ` or `ctest -R 'c'`. # TARGET: The target corresponding to the executable under test. # DISABLE_EXECUTABLE_INSTALL_RULE: An option, if given, that disables creating an install rule for TARGET. # DEPENDS: A list of targets or files on which this test is dependent. # DEPENDEES: A list of targets which should depend on this test. # TEST_COMMAND_OPTIONS: A list of variables (i.e. by reference params) which contain command line arguments -# to pass to the test executable. A unique test with suffix _0, _1, ... is generated for each set of +# to pass to the test executable. A unique test is generated for each set of # options given. If this option is not used, a single test with no arguments is generated. +# TEST_COMMAND_OPTIONS_PREFIX: If provided, is added as a prefix to each TEST_COMMAND_OPTIONS value for +# generating the full variable name to be referenced. # RESULT_CACHE_FILE: A file to be installed alongside the test executable with pre-computed # test results to speed up test runtime. -# +# TEST_SETS_SUPPORTED: A list of test set names these tests support. +# - set(options DISABLE_EXECUTABLE_INSTALL_RULE) - set(oneValueArgs DISABLE_TESTS RESULT_CACHE_FILE) - set(multiValueArgs DEPENDS DEPENDEES TEST_COMMAND_OPTIONS) + set(options DISABLE_EXECUTABLE_INSTALL_RULE DO_NOT_LOWERCASE_TEST_NAME) + set(oneValueArgs DISABLE_TESTS RESULT_CACHE_FILE TEST_COMMAND_OPTIONS_PREFIX) + set(multiValueArgs DEPENDS DEPENDEES TEST_COMMAND_OPTIONS TEST_SETS_SUPPORTED) cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) if (NOT DEFINED __DISABLE_TESTS) set(__DISABLE_TESTS OFF) endif() + set(TEST_EXE $) + set(TEST_EXE_WORKING_DIRECTORY ./${CMAKE_INSTALL_BINDIR}) + + if (NOT DEFINED __TEST_SETS_SUPPORTED) + set(__TEST_SETS_SUPPORTED ${CUTLASS_DEFAULT_ACTIVE_TEST_SETS}) + endif() + + set(TEST_SETS_SUPPORTED ${__TEST_SETS_SUPPORTED}) + if (__RESULT_CACHE_FILE) add_custom_command( @@ -624,9 +833,9 @@ function(cutlass_add_executable_tests NAME TARGET) endif() if (NOT __DISABLE_EXECUTABLE_INSTALL_RULE AND CUTLASS_INSTALL_TESTS) - + # file(RELATIVE_PATH CMAKE_CURRENT_BINARY_RELATIVE_DIR ${CMAKE_BINARY_DIR} ${CMAKE_CURRENT_BINARY_DIR}) - + install( TARGETS ${TARGET} RUNTIME DESTINATION ${CUTLASS_TEST_INSTALL_BINDIR} @@ -640,7 +849,7 @@ function(cutlass_add_executable_tests NAME TARGET) ) endif() - + endif() if (NOT __TEST_COMMAND_OPTIONS) @@ -648,7 +857,6 @@ function(cutlass_add_executable_tests NAME TARGET) endif() list(LENGTH __TEST_COMMAND_OPTIONS CMD_COUNT) - set(CMD_IDX 0) if (CMD_COUNT GREATER 1) add_custom_target(${NAME} DEPENDS ${TARGET} ${__DEPENDS}) @@ -657,74 +865,96 @@ function(cutlass_add_executable_tests NAME TARGET) endforeach() endif() - foreach(CMD_OPTIONS ${__TEST_COMMAND_OPTIONS}) + if (CUTLASS_INSTALL_TESTS) + + set(_INLINE_PER_TEST_CODE) + + file(READ "${PROJECT_SOURCE_DIR}/cmake/CTestTestfile.test.configure.cmake" _INLINE_PER_TEST_CODE_TEMPLATE) + + endif() + + set(TEST_GROUP_NAME ${NAME}) + + # To run the tests from an install package with tests enabled, we need to generate test files + # that don't rely on the current directory structure in build. 
+ + set(TEST_NAME c${NAME}) + set(TEST_GEN_DIR ${CMAKE_CURRENT_BINARY_DIR}/ctest/${TEST_NAME}) + file(MAKE_DIRECTORY ${TEST_GEN_DIR}) + + set(TEST_EXE_PATH $) + set(TEST_USE_EXTENDED_FORMAT ON) + configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake" @ONLY) + + set(TEST_EXE_PATH $) + set(TEST_USE_EXTENDED_FORMAT OFF) # ctest does not support extended add_test format. + configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake.in" @ONLY) + + foreach(CMD_OPTIONS_VAR IN LISTS __TEST_COMMAND_OPTIONS) if (CMD_COUNT GREATER 1) - set(TEST_NAME ${NAME}_${CMD_IDX}) + set(TESTCASE_NAME "${NAME}_${CMD_OPTIONS_VAR}") else() - set(TEST_NAME ${NAME}) + set(TESTCASE_NAME "${NAME}") endif() - # The following rigmarole is needed to deal with spaces and possible quotes in + if (NOT __DO_NOT_LOWERCASE_TEST_NAME) + string(TOLOWER "${TESTCASE_NAME}" TESTCASE_NAME) + endif() + + # The following rigmarole is needed to deal with spaces and possible quotes in # command line arguments. The options are passed "by reference" as the actual # variable names holding the real options. We then expand these in a way that - # preserves any quotes. Note, they have to be in this order for it to work for + # preserves any quotes. Note, they have to be in this order for it to work for # all the use cases below. - set(CMD_OPTIONS ${${CMD_OPTIONS}}) - list(JOIN CMD_OPTIONS " " TEST_COMMAND_OPTIONS) - separate_arguments(CMD_OPTIONS) - + set(TEST_COMMAND_OPTIONS ${${__TEST_COMMAND_OPTIONS_PREFIX}${CMD_OPTIONS_VAR}}) + list(JOIN TEST_COMMAND_OPTIONS " " TEST_COMMAND_OPTIONS) + separate_arguments(TEST_COMMAND_OPTIONS) + add_custom_target( - ${TEST_NAME} + ${TESTCASE_NAME} COMMAND - ${CUTLASS_TEST_EXECUTION_ENVIRONMENT} $ ${CMD_OPTIONS} + ${CUTLASS_TEST_EXECUTION_ENVIRONMENT} $ ${TEST_COMMAND_OPTIONS} DEPENDS ${TARGET} ) if (CMD_COUNT GREATER 1) - add_dependencies(${NAME} ${TEST_NAME}) + add_dependencies(${NAME} ${TESTCASE_NAME}) endif() foreach(DEPENDEE ${__DEPENDEES}) - add_dependencies(${DEPENDEE} ${TEST_NAME}) + add_dependencies(${DEPENDEE} ${TESTCASE_NAME}) endforeach() - add_test( - NAME c${TEST_NAME} - COMMAND ${CUTLASS_TEST_EXECUTION_ENVIRONMENT} $ ${CMD_OPTIONS} - ) + set(TESTCASE_NAME c${TESTCASE_NAME}) + string(CONFIGURE "${_INLINE_PER_TEST_CODE_TEMPLATE}" _TEST_CODE @ONLY) + file(APPEND "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake" "${_TEST_CODE}") + file(APPEND "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake.in" "${_TEST_CODE}") - set_tests_properties(c${TEST_NAME} PROPERTIES DISABLED ${__DISABLE_TESTS}) + endforeach() - if (CUTLASS_INSTALL_TESTS) + # The following line imports the tests for immediate run via `make test`. - # To run the tests from an install package with tests enabled, we need to generate test files - # that don't rely on the current directory structure in build. 
+ include(${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.cmake) - set(TEST_NAME c${TEST_NAME}) - set(TEST_EXE $) - set(TEST_EXE_WORKING_DIRECTORY ./${CMAKE_INSTALL_BINDIR}) - configure_file("${CUTLASS_CTEST_TEMPLATE_FILE}" "${CMAKE_PROJECT_DIR}${CMAKE_CURRENT_BINARY_DIR}/CTestTestfile.${TEST_NAME}.config.cmake" @ONLY) + set(CUTLASS_CTEST_GENERATED_FILES ${CUTLASS_CTEST_GENERATED_FILES};ctest/${TEST_NAME}/CTestTestfile.${TEST_NAME}.cmake CACHE INTERNAL "") - file(GENERATE - OUTPUT "${CMAKE_PROJECT_DIR}${CMAKE_CURRENT_BINARY_DIR}/CTestTestfile.${TEST_NAME}.cmake" - INPUT "${CMAKE_PROJECT_DIR}${CMAKE_CURRENT_BINARY_DIR}/CTestTestfile.${TEST_NAME}.config.cmake" - ) - - install( - FILES "${CMAKE_PROJECT_DIR}${CMAKE_CURRENT_BINARY_DIR}/CTestTestfile.${TEST_NAME}.cmake" - DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/ctest/ - ) - - set(CUTLASS_CTEST_GENERATED_FILES ${CUTLASS_CTEST_GENERATED_FILES};ctest/CTestTestfile.${TEST_NAME}.cmake CACHE INTERNAL "") - - endif() + if (CUTLASS_INSTALL_TESTS) - math(EXPR CMD_IDX "${CMD_IDX} + 1") + file(GENERATE + OUTPUT "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake" + INPUT "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake.in" + ) - endforeach() + install( + FILES "${TEST_GEN_DIR}/CTestTestfile.${TEST_NAME}.install.cmake" + DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/ctest/${TEST_NAME} + RENAME CTestTestfile.${TEST_NAME}.cmake + ) + + endif() endfunction() @@ -732,8 +962,9 @@ if (CUTLASS_ENABLE_TOOLS) add_subdirectory(tools) if (CUTLASS_ENABLE_PROFILER) add_dependencies(test_all test_profiler) - endif() + endif() endif() + if (CUTLASS_ENABLE_EXAMPLES) add_subdirectory(examples) add_dependencies(test_all test_examples) @@ -741,52 +972,61 @@ endif() if (CUTLASS_ENABLE_TESTS) add_subdirectory(test) + if (CUTLASS_ENABLE_GTEST_UNIT_TESTS) add_dependencies(test_all test_unit) + endif() endif() if (CUTLASS_INSTALL_TESTS) - file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/cmake") + file(MAKE_DIRECTORY "${CMAKE_BINARY_DIR}/ctest") + + file(WRITE "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "# Generated File\n\n") + file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "cmake_policy(SET CMP0057 NEW) # Allow IN_LIST for if()\n\n") + file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "if (NOT DEFINED ENV{CUTLASS_TEST_SETS})\n") + file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" " set(ENV{CUTLASS_TEST_SETS} ${CUTLASS_DEFAULT_ACTIVE_TEST_SETS})\n") + file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "endif()\n\n") - file(WRITE "${CMAKE_BINARY_DIR}/cmake/CTestTestfile.cmake" "# Generated File\n") foreach(GENERATED_FILE ${CUTLASS_CTEST_GENERATED_FILES}) - file(APPEND "${CMAKE_BINARY_DIR}/cmake/CTestTestfile.cmake" "include(${GENERATED_FILE})\n") + file(APPEND "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" "include(${GENERATED_FILE})\n") endforeach() install( - FILES "${CMAKE_BINARY_DIR}/cmake/CTestTestfile.cmake" + FILES "${CMAKE_BINARY_DIR}/ctest/CTestTestfile.cmake" DESTINATION "${CUTLASS_TEST_INSTALL_PREFIX}/" ) endif() -#? install( -#? FILES ${CMAKE_BINARY_DIR}/CTestTestfile.cmake -#? DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/ -#? ) -#? -#? install( -#? DIRECTORY -#? ${CMAKE_BINARY_DIR}/tools -#? ${CMAKE_BINARY_DIR}/test -#? DESTINATION ${CUTLASS_TEST_INSTALL_PREFIX}/ -#? FILES_MATCHING PATTERN "CTestTestfile.cmake" -#? 
) - ################################################################################ +include(CMakePackageConfigHelpers) + +write_basic_package_version_file( + ${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfigVersion.cmake + COMPATIBILITY AnyNewerVersion) + +configure_file( + ${CMAKE_CURRENT_SOURCE_DIR}/cmake/NvidiaCutlassConfig.cmake.in + ${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfig.cmake + @ONLY + ) + install( - FILES ${CMAKE_CURRENT_SOURCE_DIR}/cmake/NvidiaCutlassConfig.cmake - DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ + FILES + ${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfig.cmake + ${CMAKE_CURRENT_BINARY_DIR}/NvidiaCutlassConfigVersion.cmake + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/NvidiaCutlass/ ) install( EXPORT NvidiaCutlass NAMESPACE nvidia::cutlass:: - DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/ + DESTINATION ${CMAKE_INSTALL_LIBDIR}/cmake/NvidiaCutlass/ FILE NvidiaCutlassTargets.cmake ) ################################################################################ include(${CMAKE_CURRENT_SOURCE_DIR}/cmake/NvidiaCutlassPackageConfig.cmake) + diff --git a/CONTRIBUTORS.md b/CONTRIBUTORS.md index dccfbda6fc..538bb65843 100644 --- a/CONTRIBUTORS.md +++ b/CONTRIBUTORS.md @@ -1,65 +1,87 @@ -![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS") +![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "CUTLASS") -[README](/README.md#documentation) > **Contributors** +[README](./README.md#documentation) > **Contributors** # CUTLASS Developers and Contributors This is the official list of CUTLASS developers and contributors. ## DEVELOPERS -Andrew Kerr -Haicheng Wu -Manish Gupta -Dustyn Blasig -Pradeep Ramani -Naila Farooqui -Piotr Majcher -Paul Springer -Jin Wang -Aniket Shivam -Chinmay Talegaonkar -Shang Zhang -Scott Yokim -Markus Hohnerbach -Aditya Atluri -David Tanner -Manikandan Ananth +Vijay Thakkar
+Pradeep Ramani
+Cris Cecka
+Aniket Shivam
+Jack Kosaian
+Mark Hoemmen
+Richard Cai
+Honghao Lu
+Ethan Yan
+Haicheng Wu
+Andrew Kerr
+Dustyn Blasig
+Fengqi Qiao
+Duane Merrill
+Yujia Zhai
+Rawn Henry
+Sergey Klevtsov
+Shang Zhang
+Piotr Majcher
+Paul Springer
+Markus Hohnerbach
+Jin Wang
+Aditya Atluri
+
+## CuTe
+Cris Cecka
+Vijay Thakkar
 ## CUTLASS Product Manager
-Matthew Nicely
-
-## CONTRIBUTORS
-Timothy Costa
-Julien Demouth
-Brian Fahs
-Michael Goldfarb
-Mostafa Hagog
-Fei Hu
-Alan Kaatz
-Tina Li
-Timmy Liu
-Duane Merrill
-Kevin Siu
-Markus Tavenrath
-John Tran
-Vicki Wang
-Junkai Wu
-Fung Xie
-Albert Xu
-Jack Yang
-Xiuxia Zhang
-Nick Zhao
+Matthew Nicely
-## ACKNOWLEDGEMENTS
+## Former CUTLASS Developers
+Manish Gupta
+Naila Farooqui
+David Tanner
+Manikandan Ananth
+Zhaodong Chen
+Chinmay Talegaonkar
-Girish Bharambe
-Cris Cecka
-Luke Durant
-Olivier Giroux
-Stephen Jones
-Rishkul Kulkarni
-Bryce Lelbach
-Joel McCormack
-Kyrylo Perelygin
+## CONTRIBUTORS
+Timothy Costa
+Julien Demouth
+Brian Fahs
+Michael Garland
+Michael Goldfarb
+Mostafa Hagog
+Fei Hu
+Alan Kaatz
+Tina Li
+Timmy Liu
+Wei Liu
+Tim Martin
+Duane Merrill
+Kevin Siu
+Markus Tavenrath
+John Tran
+Vicki Wang
+Junkai Wu
+Fung Xie
+Albert Xu
+Yang Xu
+Jack Yang
+Scott Yokim
+Xiuxia Zhang
+Nick Zhao
+## ACKNOWLEDGEMENTS +Girish Bharambe
+Luke Durant
+Carter Edwards
+Olivier Giroux
+Stephen Jones
+Rishkul Kulkarni
+Bryce Lelbach
+Joel McCormack
+Kyrylo Perelygin
+Sean Treichler
diff --git a/CUDA.cmake b/CUDA.cmake index ff6a6afc8a..7e91adb88d 100644 --- a/CUDA.cmake +++ b/CUDA.cmake @@ -1,4 +1,4 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -26,49 +26,46 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. -if(CUDA_COMPILER MATCHES "[Cc]lang") - set(CUTLASS_NATIVE_CUDA_INIT ON) -elseif(CMAKE_VERSION VERSION_LESS 3.12.4) - set(CUTLASS_NATIVE_CUDA_INIT OFF) -else() - set(CUTLASS_NATIVE_CUDA_INIT ON) +if (CUDA_COMPILER MATCHES "[Cc]lang") + message(WARNING "CUDA_COMPILER flag is deprecated, set CMAKE_CUDA_COMPILER to desired compiler executable.") + set(__CLANG_DEVICE_COMPILATION_REQUESTED ON) +elseif(CUDA_COMPILER) + message(WARNING "Deprecated flag CUDA_COMPILER used with unknown argument ${CUDA_COMPILER}, ignoring.") endif() -set(CUTLASS_NATIVE_CUDA ${CUTLASS_NATIVE_CUDA_INIT} CACHE BOOL "Utilize the CMake native CUDA flow") - -if(NOT DEFINED ENV{CUDACXX} AND NOT DEFINED ENV{CUDA_BIN_PATH} AND DEFINED ENV{CUDA_PATH}) - # For backward compatibility, allow use of CUDA_PATH. - set(ENV{CUDACXX} $ENV{CUDA_PATH}/bin/nvcc) +if (__CLANG_DEVICE_COMPILATION_REQUESTED AND NOT DEFINED CMAKE_CUDA_COMPILER) + set(CMAKE_CUDA_COMPILER clang++) # We will let the system find Clang or error out endif() -if(CUTLASS_NATIVE_CUDA) +enable_language(CUDA) +find_package(CUDAToolkit REQUIRED) - enable_language(CUDA) - - if(NOT CUDA_VERSION) - set(CUDA_VERSION ${CMAKE_CUDA_COMPILER_VERSION}) - endif() - if(NOT CUDA_TOOLKIT_ROOT_DIR) - get_filename_component(CUDA_TOOLKIT_ROOT_DIR "${CMAKE_CUDA_COMPILER}/../.." ABSOLUTE) - endif() +if(NOT CUDA_VERSION) + # For backward compatibility with older CMake code. + set(CUDA_VERSION ${CUDAToolkit_VERSION}) + set(CUDA_VERSION_MAJOR ${CUDAToolkit_VERSION_MAJOR}) + set(CUDA_VERSION_MINOR ${CUDAToolkit_VERSION_MINOR}) +endif() +if(NOT CUDA_TOOLKIT_ROOT_DIR) + # In some scenarios, such as clang device compilation, the toolkit root may not be set, so we + # force it here to the nvcc we found via the CUDAToolkit package. + get_filename_component(CUDA_TOOLKIT_ROOT_DIR "${CUDAToolkit_NVCC_EXECUTABLE}/../.." ABSOLUTE) +endif() +if (CMAKE_CUDA_COMPILER_ID MATCHES "(nvcc|[Nn][Vv][Ii][Dd][Ii][Aa])") + set(CUTLASS_NVCC_DEVICE_COMPILE ON CACHE BOOL "Using nvcc tools for device compilation") +elseif (CMAKE_CUDA_COMPILER_ID MATCHES "[Cc]lang") + set(CUTLASS_CLANG_DEVICE_COMPILE ON CACHE BOOL "Using Clang tools for device compilation") else() + message(FATAL_ERROR "Uknown device-side compiler ${CMAKE_CUDA_COMPILER_ID} found. Set CMAKE_CUDA_COMPILER to either nvcc or clang++.") +endif() - find_package(CUDA REQUIRED) - # We workaround missing variables with the native flow by also finding the CUDA toolkit the old way. 
- - if(NOT CMAKE_CUDA_COMPILER_VERSION) - set(CMAKE_CUDA_COMPILER_VERSION ${CUDA_VERSION}) - endif() - +if (CUTLASS_CLANG_DEVICE_COMPILE AND CMAKE_VERSION VERSION_LESS_EQUAL "3.30") + message(FATAL_ERROR "Clang device compilation for CUTLASS requires CMake 3.30 or higher.") endif() if (CUDA_VERSION VERSION_LESS 9.2) - message(FATAL_ERROR "CUDA 9.2+ Required, Found ${CUDA_VERSION}.") -endif() -if(NOT CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "[Cc]lang") - set(CMAKE_CUDA_COMPILER ${CUDA_TOOLKIT_ROOT_DIR}/bin/nvcc) - message(STATUS "CUDA Compiler: ${CMAKE_CUDA_COMPILER}") + message(FATAL_ERROR "CUDA 9.2+ required, found ${CUDA_VERSION}.") endif() find_library( @@ -76,11 +73,12 @@ find_library( PATHS ${CUDA_TOOLKIT_ROOT_DIR} PATH_SUFFIXES + lib/x86_64-linux-gnu lib/x64 lib64 lib NO_DEFAULT_PATH - # We aren't going to search any system paths. We want to find the runtime + # We aren't going to search any system paths. We want to find the runtime # in the CUDA toolkit we're building against. ) @@ -95,10 +93,10 @@ if(NOT TARGET cudart AND CUDART_LIBRARY) # from the PATH search. else() add_library(cudart SHARED IMPORTED GLOBAL) - endif() + endif() add_library(nvidia::cudart ALIAS cudart) - + set_property( TARGET cudart PROPERTY IMPORTED_LOCATION @@ -120,13 +118,14 @@ find_library( PATHS ${CUDA_TOOLKIT_ROOT_DIR} PATH_SUFFIXES + lib/x86_64-linux-gnu lib/x64 lib64 lib lib64/stubs lib/stubs NO_DEFAULT_PATH - # We aren't going to search any system paths. We want to find the runtime + # We aren't going to search any system paths. We want to find the runtime # in the CUDA toolkit we're building against. ) @@ -141,10 +140,10 @@ if(NOT TARGET cuda_driver AND CUDA_DRIVER_LIBRARY) # from the PATH search. else() add_library(cuda_driver SHARED IMPORTED GLOBAL) - endif() + endif() add_library(nvidia::cuda_driver ALIAS cuda_driver) - + set_property( TARGET cuda_driver PROPERTY IMPORTED_LOCATION @@ -170,7 +169,7 @@ find_library( lib64 lib NO_DEFAULT_PATH - # We aren't going to search any system paths. We want to find the runtime + # We aren't going to search any system paths. We want to find the runtime # in the CUDA toolkit we're building against. ) @@ -185,10 +184,10 @@ if(NOT TARGET nvrtc AND NVRTC_LIBRARY) # from the PATH search. else() add_library(nvrtc SHARED IMPORTED GLOBAL) - endif() - + endif() + add_library(nvidia::nvrtc ALIAS nvrtc) - + set_property( TARGET nvrtc PROPERTY IMPORTED_LOCATION @@ -209,16 +208,6 @@ include_directories(SYSTEM ${CUDA_INCLUDE_DIRS}) # Some platforms (e.g. Visual Studio) don't add the CUDA include directories to the system include # paths by default, so we add it explicitly here. 
-function(cutlass_correct_source_file_language_property) - if(CUDA_COMPILER MATCHES "[Cc]lang") - foreach(File ${ARGN}) - if(File MATCHES ".*\.cu$") - set_source_files_properties(${File} PROPERTIES LANGUAGE CXX) - endif() - endforeach() - endif() -endfunction() - if (MSVC OR CUTLASS_LIBRARY_KERNELS MATCHES "all") set(CUTLASS_UNITY_BUILD_ENABLED_INIT ON) else() @@ -226,7 +215,14 @@ else() endif() set(CUTLASS_UNITY_BUILD_ENABLED ${CUTLASS_UNITY_BUILD_ENABLED_INIT} CACHE BOOL "Enable combined source compilation") -set(CUTLASS_UNITY_BUILD_BATCH_SIZE 16 CACHE STRING "Batch size for unified source files") + +if (MSVC) + set(CUTLASS_UNITY_BUILD_BATCH_SIZE_INIT 8) +else() + set(CUTLASS_UNITY_BUILD_BATCH_SIZE_INIT 16) +endif() + +set(CUTLASS_UNITY_BUILD_BATCH_SIZE ${CUTLASS_UNITY_BUILD_BATCH_SIZE_INIT} CACHE STRING "Batch size for unified source files") function(cutlass_unify_source_files TARGET_ARGS_VAR) @@ -239,15 +235,19 @@ function(cutlass_unify_source_files TARGET_ARGS_VAR) message(FATAL_ERROR "TARGET_ARGS_VAR parameter is required") endif() + if (NOT DEFINED __BATCH_SOURCES) + set(__BATCH_SOURCES ON) + endif() + if (__BATCH_SOURCES AND NOT DEFINED __BATCH_SIZE) set(__BATCH_SIZE ${CUTLASS_UNITY_BUILD_BATCH_SIZE}) endif() - if (CUTLASS_UNITY_BUILD_ENABLED AND DEFINED __BATCH_SIZE AND __BATCH_SIZE GREATER 1) + if (CUTLASS_UNITY_BUILD_ENABLED AND __BATCH_SOURCES AND __BATCH_SIZE GREATER 1) set(CUDA_FILE_ARGS) set(TARGET_SOURCE_ARGS) - + foreach(ARG ${__UNPARSED_ARGUMENTS}) if(${ARG} MATCHES ".*\.cu$") list(APPEND CUDA_FILE_ARGS ${ARG}) @@ -255,7 +255,7 @@ function(cutlass_unify_source_files TARGET_ARGS_VAR) list(APPEND TARGET_SOURCE_ARGS ${ARG}) endif() endforeach() - + list(LENGTH CUDA_FILE_ARGS NUM_CUDA_FILE_ARGS) while(NUM_CUDA_FILE_ARGS GREATER 0) list(SUBLIST CUDA_FILE_ARGS 0 ${__BATCH_SIZE} CUDA_FILE_BATCH) @@ -287,23 +287,20 @@ function(cutlass_unify_source_files TARGET_ARGS_VAR) endfunction() function(cutlass_add_library NAME) - set(options) + set(options SKIP_GENCODE_FLAGS) set(oneValueArgs EXPORT_NAME) set(multiValueArgs) cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS}) - - if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang") - cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS}) - add_library(${NAME} ${TARGET_SOURCE_ARGS}) - else() - set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE) - cuda_add_library(${NAME} ${TARGET_SOURCE_ARGS}) - endif() + + add_library(${NAME} ${TARGET_SOURCE_ARGS} "") cutlass_apply_standard_compile_options(${NAME}) - cutlass_apply_cuda_gencode_flags(${NAME}) + + if (NOT __SKIP_GENCODE_FLAGS) + cutlass_apply_cuda_gencode_flags(${NAME}) + endif() target_compile_features( ${NAME} @@ -311,6 +308,14 @@ function(cutlass_add_library NAME) cxx_std_11 ) + get_target_property(TARGET_TYPE ${NAME} TYPE) + + if (TARGET_TYPE MATCHES "SHARED") + set_target_properties(${NAME} PROPERTIES CUDA_RUNTIME_LIBRARY Shared) + elseif(TARGET_TYPE MATCHES "STATIC") + set_target_properties(${NAME} PROPERTIES CUDA_RUNTIME_LIBRARY Static) + endif() + if(__EXPORT_NAME) add_library(nvidia::cutlass::${__EXPORT_NAME} ALIAS ${NAME}) set_target_properties(${NAME} PROPERTIES EXPORT_NAME ${__EXPORT_NAME}) @@ -321,20 +326,23 @@ endfunction() function(cutlass_add_executable NAME) set(options) - set(oneValueArgs) + set(oneValueArgs CUDA_RUNTIME_LIBRARY) set(multiValueArgs) cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) - 
cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS}) + if (NOT DEFINED __CUDA_RUNTIME_LIBRARY) + set(__CUDA_RUNTIME_LIBRARY Shared) + endif() - if(CUTLASS_NATIVE_CUDA OR CUDA_COMPILER MATCHES "clang") - cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS}) - add_executable(${NAME} ${TARGET_SOURCE_ARGS}) - else() - set(CUDA_LINK_LIBRARIES_KEYWORD PRIVATE) - cuda_add_executable(${NAME} ${TARGET_SOURCE_ARGS}) + set(__CUDA_RUNTIME_LIBRARY_ALLOWED None Shared Static) + if (NOT __CUDA_RUNTIME_LIBRARY IN_LIST __CUDA_RUNTIME_LIBRARY_ALLOWED) + message(FATAL_ERROR "CUDA_RUNTIME_LIBRARY value '${__CUDA_RUNTIME_LIBRARY}' is not in allowed list of '${__CUDA_RUNTIME_LIBRARY_ALLOWED}'") endif() + cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS}) + + add_executable(${NAME} ${TARGET_SOURCE_ARGS}) + cutlass_apply_standard_compile_options(${NAME}) cutlass_apply_cuda_gencode_flags(${NAME}) @@ -344,6 +352,8 @@ function(cutlass_add_executable NAME) cxx_std_11 ) + set_target_properties(${NAME} PROPERTIES CUDA_RUNTIME_LIBRARY ${__CUDA_RUNTIME_LIBRARY}) + endfunction() function(cutlass_target_sources NAME) @@ -354,7 +364,6 @@ function(cutlass_target_sources NAME) cmake_parse_arguments(_ "${options}" "${oneValueArgs}" "${multiValueArgs}" ${ARGN}) cutlass_unify_source_files(TARGET_SOURCE_ARGS ${__UNPARSED_ARGUMENTS}) - cutlass_correct_source_file_language_property(${TARGET_SOURCE_ARGS}) target_sources(${NAME} ${TARGET_SOURCE_ARGS}) endfunction() diff --git a/LICENSE.txt b/LICENSE.txt index d9219ec9b9..525500841e 100644 --- a/LICENSE.txt +++ b/LICENSE.txt @@ -1,4 +1,4 @@ -Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. SPDX-License-Identifier: BSD-3-Clause Redistribution and use in source and binary forms, with or without diff --git a/PUBLICATIONS.md b/PUBLICATIONS.md index 5c36742a88..ba0ef4cff8 100644 --- a/PUBLICATIONS.md +++ b/PUBLICATIONS.md @@ -1,16 +1,56 @@ # Publications Using Cutlass +## 2024 + +- ["ShadowKV: KV Cache in Shadows for High-Throughput Long-Context LLM Inference"](https://arxiv.org/abs/2410.21465). Hanshi Sun, Li-Wen Chang, Wenlei Bao, Size Zheng, Ningxin Zheng, Xin Liu, Harry Dong, Yuejie Chi, Beidi Chen. _arXiv_, October 2024. + +- ["FLUX: Fast Software-based Communication Overlap On GPUs Through Kernel Fusion"](https://arxiv.org/abs/2406.06858). Li-Wen Chang, Wenlei Bao, Qi Hou, Chengquan Jiang, Ningxin Zheng, Yinmin Zhong, Xuanrun Zhang, Zuquan Song, Chengji Yao, Ziheng Jiang, Haibin Lin, Xin Jin, Xin Liu. _arXiv_, June 2024. + +- ["EVT: Accelerating Deep Learning Training with Epilogue Visitor Tree"](https://dl.acm.org/doi/10.1145/3620666.3651369). Zhaodong Chen, Andrew Kerr, Richard Cai, Jack Kosaian, Haicheng Wu, Yufei Ding, and Yuan Xie. _Proceedings of the 29th ACM International Conference on Architectural Support for Programming Languages and Operating Systems_, April 2024. + +- ["Faster Neighborhood Attention: Reducing the O(n^2) Cost of Self Attention at the Threadblock Level"](https://arxiv.org/abs/2403.04690). Ali Hassani, Wen-Mei Hwu, Humphrey Shi. _arXiv_, March 2024. + +## 2023 + +- ["A Case Study in CUDA Kernel Fusion: Implementing FlashAttention-2 on NVIDIA Hopper Architecture using the CUTLASS Library"](https://arxiv.org/abs/2312.11918). Ganesh Bikshandi, Jay Shah. _arXiv_, December 2023. 
+ +- ["Benchmarking GPU Tensor Cores on General Matrix Multiplication Kernels through CUTLASS"](https://www.mdpi.com/2076-3417/13/24/13022). Xuanteng Huang, Xianwei Zhang, Panfei Yang, Nong Xiao. _Journal of Applied Sciences_, December 2023. + +- ["A Speed Odyssey for Deployable Quantization of LLMs"](https://arxiv.org/abs/2311.09550). Qingyuan Li, Ran Meng, Yiduo Li, Bo Zhang, Liang Li, Yifan Lu, Xiangxiang Chu, Yerui Sun, Yuchen Xie. _arXiv_, November 2023. + +- ["FlashAttention-2: Faster Attention with Better Parallelism and Work Partitioning"](https://arxiv.org/abs/2307.08691). Tri Dao. _Technical Report_, July 2023. + +- ["MegaBlocks: Efficient Sparse Training with Mixture-of-Experts"](https://arxiv.org/abs/2211.15841). Trevor Gale, Deepak Narayanan, Cliff Young, Matei Zaharia. _Proceedings of the Sixth Machine Learning and Systems_, May 2023. + +- ["ByteTransformer: A High-Performance Transformer Boosted for Variable-Length Inputs"](https://arxiv.org/abs/2210.03052). Yujia Zhai, Chengquan Jiang, Leyuan Wang, Xiaoying Jia, Shang Zhang, Zizhong Chen, Xin Liu, Yibo Zhu. _Proceedings of the 37th IEEE International Parallel & Distributed Processing Symposium (Best Paper)_, May 2023. + +- ["A Framework for Fine-Grained Synchronization of Dependent GPU Kernels"](https://arxiv.org/abs/2305.13450). Abhinav Jangda, Saeed Maleki, Maryam Mehri Dehnavi, Madan Musuvathi, Olli Saarikivi. _Computing Research Repository_, May 2023. + +- ["Graphene: An IR for Optimized Tensor Computations on GPUs"](https://dl.acm.org/doi/pdf/10.1145/3582016.3582018). Hagedorn, Bastian, Bin Fan, Hanfeng Chen, Cris Cecka, Michael Garland, Vinod Grover. _Proceedings of the 28th ACM International Conference on Architectural Support for Programming Languages and Operating Systems_, March 2023. + +- ["Mixed Precision Post Training Quantization of Neural Networks with Sensitivity Guided Search"](https://arxiv.org/abs/2302.01382). Clemens JS Schaefer, Elfie Guo, Caitlin Stanton, Xiaofan Zhang, Tom Jablin, Navid Lambert-Shirzad, Jian Li, Chiachen Chou, Siddharth Joshi, Yu Emma Wang. _arXiv_, Feburary 2023. + +- ["Dynamic N:M Fine-Grained Structured Sparse Attention Mechanism"](https://dl.acm.org/doi/abs/10.1145/3572848.3577500). Zhaodong Chen, Zheng Qu, Yuying Quan, Liu Liu, Yufei Ding, Yuan Xie. _Proceedings of the 28th ACM SIGPLAN Annual Symposium on Principles and Practice of Parallel Programming_, Feburary 2023. + +- ["Stream-K: Work-centric Parallel Decomposition for Dense Matrix-Matrix Multiplication on the GPU"](https://arxiv.org/abs/2301.03598). Muhammad Osama, Duane Merrill, Cris Cecka, Michael Garland, John D. Owens. _arXiv_, January 2023. + ## 2022 +- ["GPU Load Balancing"](https://arxiv.org/abs/2212.08964). Muhammad Osama. _Doctoral dissertation, University of California, Davis_, December 2022. + +- ["Who Says Elephants Can't Run: Bringing Large Scale MoE Models into Cloud Scale Production"](https://arxiv.org/abs/2211.10017). Young Jin Kim, Rawn Henry, Raffy Fahim, Hany Hassan Awadalla. _Proceedings of the Third Workshop on Simple and Efficient Natural Language Processing_, December 2022. + - ["Bolt: Bridging the Gap between Auto-tuners and Hardware-native Performance"](https://arxiv.org/abs/2110.15238). Jiarong Xing, Leyuan Wang, Shang Zhang, Jack Chen, Ang Chen, Yibo Zhu. _Proceedings of the 5th MLSys Conference_, August 2022. - ["Recovering single precision accuracy from Tensor Cores while surpassing the FP32 theoretical peak performance"](https://arxiv.org/abs/2203.03341). Hiroyuki Ootomo, Rio Yokota. 
_International Journal of High Performance Computing_, March 2022. +- ["Breaking the Computation and Communication Abstraction Barrier in Distributed Machine Learning Workloads"](https://arxiv.org/abs/2105.05720). Abhinav Jangda, Jun Huang, Guodong Liu, Amir Hossein Nodehi Sabet, Saeed Maleki, Youshan Miao, Madanlal Musuvathi, Todd Mytkowicz, Olli Sarikivi. _Proceedings of the 27th ACM International Conference on Architectural Support for Programming Languages and Operating Systems_, February 2022. + ## 2021 - ["Arithmetic-intensity-guided fault tolerance for neural network inference on GPUs"](https://dl.acm.org/doi/abs/10.1145/3458817.3476184). Jack Kosaian, K. V. Rashmi. _Proceedings of the International Conference for High Performance Computing, Networking, Storage and Analysis_, November 2021. -- ["Real-time Neural Radiance Caching for Path Tracing"](https://d1qx31qr3h6wln.cloudfront.net/publications/paper_4.pdf). Thomas Muller, Fabrice Rousselle, Jan Novak, Alex Keller. _ACM Trans. Graph._, August 2021. +- ["Real-time Neural Radiance Caching for Path Tracing"](https://dl.acm.org/doi/abs/10.1145/3450626.3459812). Thomas Muller, Fabrice Rousselle, Jan Novak, Alex Keller. _ACM Trans. Graph._, August 2021. ## 2020 diff --git a/README.md b/README.md index 78ca725c1a..e61335f240 100644 --- a/README.md +++ b/README.md @@ -1,153 +1,195 @@ -![ALT](/media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") +![ALT](./media/images/gemm-hierarchy-with-epilogue-no-labels.png "Complete CUDA GEMM decomposition") -# CUTLASS 2.9 +# CUTLASS 3.6.0 -_CUTLASS 2.9 - April 2022_ +_CUTLASS 3.6.0 - October 2024_ CUTLASS is a collection of CUDA C++ template abstractions for implementing -high-performance matrix-multiplication (GEMM) and related computations at all levels +high-performance matrix-matrix multiplication (GEMM) and related computations at all levels and scales within CUDA. It incorporates strategies for hierarchical decomposition and data movement similar to those used to implement cuBLAS and cuDNN. CUTLASS decomposes these "moving parts" into reusable, modular software components abstracted by C++ template -classes. These thread-wide, warp-wide, block-wide, and device-wide primitives can be specialized -and tuned via custom tiling sizes, data types, and other algorithmic policy. The -resulting flexibility simplifies their use as building blocks within custom kernels -and applications. +classes. Primitives for different levels of a conceptual parallelization hierarchy +can be specialized and tuned via custom tiling sizes, data types, +and other algorithmic policy. The resulting flexibility simplifies their use +as building blocks within custom kernels and applications. To support a wide variety of applications, CUTLASS provides extensive support for mixed-precision computations, providing specialized data-movement and multiply-accumulate abstractions for half-precision floating point (FP16), BFloat16 (BF16), Tensor Float 32 (TF32), -single-precision floating point (FP32), double-precision floating -point (FP64) types, integer data types (4b and 8b), and binary data types (1b). -CUTLASS demonstrates warp-synchronous matrix multiply operations -targeting the programmable, high-throughput _Tensor Cores_ implemented by -NVIDIA's Volta, Turing, and Ampere architectures. 
+single-precision floating point (FP32), +[FP32 emulation via tensor core instruction](./examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm), +double-precision floating +point (FP64) types, integer data types (4b and 8b), and binary data types (1b). +CUTLASS demonstrates warp-synchronous matrix multiply operations +targeting the programmable, high-throughput _Tensor Cores_ implemented by +NVIDIA's Volta, Turing, Ampere, and Hopper architectures. -CUTLASS implements high-performance Convolution via the implicit GEMM algorithm. -Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of -CUTLASS's modular GEMM pipeline. -This allows CUTLASS to build convolutions by reusing highly optimized warp-wide GEMM components and below. +See the [Quick Start Guide](./media/docs/quickstart.md) to get started quickly. -See the [Quick Start Guide](/media/docs/quickstart.md) to get started quickly. - -See the [functionality listing](/media/docs/functionality.md) for the list of operations +See the [functionality listing](./media/docs/functionality.md) for the list of operations supported at each level of the execution model hierarchy. -# What's New in CUTLASS 2.9 - -CUTLASS 2.9 is an update to CUTLASS adding: -- [First layer Convolution kernels](/test/unit/conv/device/conv2d_fprop_fixed_channels_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_sm80.cu) specialized for small channel counts and reduced alignment -- [BLAS3](https://docs.nvidia.com/cuda/cublas/index.html#cublas-level-3-function-reference) operators accelerated by Tensor Cores - - [SYRK](/test/unit/gemm/device/syrk_f32n_f32t_tensor_op_fast_f32_sm80.cu), [HERK](/test/unit/gemm/device/herk_cf32h_cf32n_tensor_op_fast_f32_sm80.cu), - - [SYR2K](/test/unit/gemm/device/syr2k_f32n_f32n_tensor_op_fast_f32_sm80.cu), [HER2K](/test/unit/gemm/device/her2k_cf32h_cf32n_tensor_op_fast_f32_sm80.cu), - - [Out-of-place TRMM](/test/unit/gemm/device/trmm_f32n_f32t_f32t_tensor_op_fast_f32_ls_sm80.cu), and - - [SYMM](/test/unit/gemm/device/symm_f32n_f32n_tensor_op_fast_f32_ls_sm80.cu), [HEMM](/test/unit/gemm/device/hemm_cf32h_cf32n_tensor_op_fast_f32_ls_sm80.cu) -- [CUTLASS Python](/examples/40_cutlass_py) demonstrating JIT compilation of CUTLASS kernels and a Python-based runtime using [CUDA Python](https://developer.nvidia.com/cuda-python) -- [GEMM + Softmax example](/examples/35_gemm_softmax) -- [Gather and Scatter Fusion with GEMM](/examples/36_gather_scatter_fusion) can gather inputs and scatters outputs based on indices vectors in the same GEMM kernel. -- [Back-to-back GEMM/CONV](examples/13_two_tensor_op_fusion) fully supports buffering the first GEMM/CONV results in the shared memory for the latter one to use. Bias Vector add is also supported in the first GEMM/CONV. -- [Transposed Convolution](/examples/34_transposed_conv2d) (a.k.a Deconvolution) support which reuses Dgrad implementation. -- [Utility functions](/tools/util/include/cutlass/util) that can pad NHWC and convert between NCHW and NHWC. -- [Small alignment implicit gemm](https://github.com/NVIDIA/cutlass/issues/242) support for Fprop/Dgrad/Wgrad so that padding is no longer mandated to use tensor cores. -- Epilogue enhancement with performance improvement, more activation functions, and more fusion patterns. -- [Group GEMM](/examples/24_gemm_grouped) thread block number calculation fix. 
-- Optimal performance using [CUDA 11.7](https://developer.nvidia.com/cuda-downloads) -- [Parallel GEMM splitk](https://github.com/NVIDIA/cutlass/pull/277) support in the CUTLASS profiler. -- Updates and bugfixes from the community (thanks!) -- **Deprecation announcement:** CUTLASS plans to deprecate the following: - - Maxwell and Pascal GPU architectures - - Ubuntu 16.04 - - CUDA 10.2 +CUTLASS 3.0 introduced a new core library, CuTe, to describe and manipulate tensors of threads and data. +CuTe is a collection of C++ CUDA template abstractions for defining and operating on hierarchically multidimensional layouts of threads and data. CuTe provides `Layout` and `Tensor` objects that compactly package the type, shape, memory space, and layout of data, while performing the complicated indexing for the user. This lets programmers focus on the logical descriptions of their algorithms while CuTe does the mechanical bookkeeping for them. With these tools, we can quickly design, implement, and modify all dense linear algebra operations. + +The core abstractions of CuTe are hierarchically multidimensional layouts which can be composed with data arrays to represent tensors. The representation of layouts is powerful enough to represent nearly everything we need to implement efficient dense linear algebra. Layouts can also be combined and manipulated via functional composition, on which we build a large set of common operations such as tiling and partitioning. + +CUTLASS 3.0 and beyond adopts CuTe throughout the GEMM hierarchy in its templates. This greatly simplifies the design +and improves code composability and readability. More documentation specific to CuTe can be found in its [dedicated documentation directory](./media/docs/cute/00_quickstart.md). + +In addition to GEMMs, CUTLASS implements high-performance convolution via the implicit GEMM algorithm. Implicit GEMM is the formulation of a convolution operation as a GEMM thereby taking advantage of CUTLASS's modular GEMM pipeline. This allows CUTLASS to build convolutions by reusing highly-optimized GEMM components. + + +# What's New in CUTLASS 3.6 + +CUTLASS 3.6.0 is an update to CUTLASS adding: + +- [Hopper structured sparse GEMM](./examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu). + + [FP16](./test/unit/gemm/device/sm90_sparse_gemm_f16_f16_f32_tensor_op_f32.cu) + + [FP8](./test/unit/gemm/device/sm90_sparse_gemm_f8_f8_f32_tensor_op_f32.cu) + + [INT8](./test/unit/gemm/device/sm90_sparse_gemm_s8_s8_s32_tensor_op_s32.cu) + + [TF32](./test/unit/gemm/device/sm90_sparse_gemm_tf32_tf32_f32_tensor_op_f32.cu) +- A refactor to the CUTLASS 3.x convolution `kernel::ConvUniversal` [API](./include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp) to bring it in line with `gemm::GemmUniversal`. Now the 3.x convolution API is no longer considered as a beta API. +- [An improved mixed input GEMM](./examples/55_hopper_mixed_dtype_gemm/README.md) and a [lookup table implementation](./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu) for `INT4`x`FP8` scale-only mode. +- [EVT nodes for Top-K selection and softmax](./include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp) and [GEMM example using those](./examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu). 
+- [Programmatic Dependent Launch](./include/cutlass/arch/grid_dependency_control.h) (PDL) that leverages a new Hopper feature to speedup two back-to-back kernels, and its corresponding [documentations](./media/docs/dependent_kernel_launch.md). +- [A new debugging tool, synclog](./include/cutlass/arch/synclog.hpp), for dumping out all synchronization events from within a kernel to a file. Please see [synclog documentation](./media/docs/utilities.md#debugging-asynchronous-kernels-with-cutlasss-built-in-synclog-tool) for details. +- A new TMA-enabled [epilogue](./include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp) for grouped GEMM that brings significant performance improvement, as well as its EVT support. +- A SIMT-enabled pointer-array [epilogue](./include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp). +- A new [Ping-Pong kernel schedule for Grouped GEMM](./include/cutlass/gemm/kernel/sm90_gemm_array_tma_warpspecialized_pingpong.hpp) and some other optimizations. +- [A new instantiation strategy for CUTLASS profiler kernels](./python/cutlass_library/sm90_shapes.py) along with [improved documentation for instantiation level in CUTLASS profiler](./media/docs/profiler.md#instantiating-more-kernels-with-hopper). +- A new hardware support for comparisons and computations of [`cutlass::bfloat16_t`](./include/cutlass/bfloat16.h) +- Fixed use of isnan on Windows for [`half_t`](./test/unit/core/functional.cu). + +Minimum requirements: + +- Architecture: Volta +- Compiler: Must support at least C++17 +- CUDA Toolkit version: 11.4 + +Starting from CUTLASS 3.0, CUTLASS removed support for the following: + +- Maxwell and Pascal GPU architectures +- Ubuntu 16.04 +- CUDA 10.2 +- C++ language versions less than 17. **See the [CHANGELOG](CHANGELOG.md) for a detailed listing of releases and updates.** # Performance -

+

+

CUTLASS primitives are very efficient. When used to construct device-wide GEMM kernels, -they exhibit performance comparable to cuBLAS for scalar GEMM -computations. The above figure shows CUTLASS performance relative to cuBLAS -for large matrix dimensions on an [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/), -an [NVIDIA A2](https://www.nvidia.com/en-us/data-center/products/a2/), -an [NVIDIA TitanV](https://www.nvidia.com/en-us/titan/titan-v/), -and an [NVIDIA GeForce 2080 Ti](https://www.nvidia.com/en-us/geforce/graphics-cards/rtx-2080-ti/) -compiled with the [CUDA 11.5 Toolkit](https://developer.nvidia.com/cuda-downloads). Tensor Core operations are implemented using CUDA's -[mma instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma). +they exhibit peak performance comparable to cuBLAS for scalar GEMM +computations. The above figure shows the continual CUTLASS performance improvements +on an [NVIDIA H100](https://www.nvidia.com/en-us/data-center/h100/) (NVIDIA Hopper architecture) since +CUTLASS 3.1. +CUTLASS 3.5.1 was compiled with the [CUDA 12.5u1 Toolkit](https://developer.nvidia.com/cuda-downloads). +Tensor Core operations are implemented using CUDA's +[mma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma) and +[wgmma](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#asynchronous-warpgroup-level-matrix-instructions) instructions. -

+

When using CUTLASS building blocks to construct device-wide implicit gemm (Fprop, Dgrad, and Wgrad) kernels, CUTLASS performance is also comparable to cuDNN when running Resnet-50 layers on an [NVIDIA A100](https://www.nvidia.com/en-us/data-center/a100/) -as shown in the above figure. Tensor Core operations are still implemented using CUDA's +as shown in the above figure. Tensor Core operations are implemented using CUDA's [mma instruction](https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-instructions-mma). # Compatibility -CUTLASS requires a C++11 host compiler and -performs best when built with the [**CUDA 11.6u2 Toolkit**](https://developer.nvidia.com/cuda-toolkit). -It is also compatible with CUDA 11.0, CUDA 11.1, CUDA 11.2, CUDA 11.3, CUDA 11.4, and CUDA 11.5. +CUTLASS requires a C++17 host compiler and +performs best when built with the [**CUDA 12.4 Toolkit**](https://developer.nvidia.com/cuda-downloads). +It is also compatible with CUDA 11.4, CUDA 11.5, CUDA 11.6, CUDA 11.7, CUDA 11.8, CUDA 12.0, CUDA 12.1, CUDA 12.2.2, CUDA 12.3.1 and CUDA 12.3.2. +## Operating Systems We have tested the following environments. |**Operating System** | **Compiler** | |-----------------|----------| -| Windows 10 | Microsoft Visual Studio 2015| -| | Microsoft Visual Studio 2017| -| | Microsoft Visual Studio 2019| -| Ubuntu 18.04 | GCC 7.5.0 | +| Ubuntu 18.04 | GCC 7.5.0 | | Ubuntu 20.04 | GCC 10.3.0 | -| Ubuntu 21.04 | GCC 11.2.0 | +| Ubuntu 22.04 | GCC 11.2.0 | +| Ubuntu 22.04 | Clang 10.0.0 | +| Ubuntu 22.04 | Clang 14.0.6 | +| Ubuntu 22.04 | Clang 17.0.6 | +| Windows 10.0 | Visual Studio 2019 v16.11.27 | + +Note: GCC 8.5.0 has known regressions regarding fold expressions and overloaded operators. Using GCC 7.5.0 or (preferred) GCC >= 9 is recommended. + +## Hardware +CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on Volta, Turing, Ampere, Ada, and Hopper architecture based NVIDIA GPUs. + +|**GPU**|**CUDA Compute Capability**|**Minimum CUDA Toolkit Required by CUTLASS-3**| +|---|---|---| +|NVIDIA V100 Tensor Core GPU |7.0|11.4| +|NVIDIA TitanV |7.0|11.4| +|NVIDIA GeForce RTX 2080 TI, 2080, 2070 |7.5|11.4| +|NVIDIA T4 |7.5|11.4| +|NVIDIA A100 Tensor Core GPU |8.0|11.4| +|NVIDIA A10 |8.6|11.4| +|NVIDIA GeForce RTX 3090 |8.6|11.4| +|NVIDIA GeForce RTX 4090 |8.9|11.8| +|NVIDIA L40 |8.9|11.8| +|NVIDIA H100 Tensor Core GPU |9.0|11.8| -Additionally, CUTLASS may be built with clang. -See [these instructions](media/docs/quickstart.md#clang) for more details. +## Target Architecture -CUTLASS runs successfully on the following NVIDIA GPUs, and it is expected to be efficient on -any Volta-, Turing-, or NVIDIA Ampere- architecture NVIDIA GPU. +In general, PTX code generated for one target architecture can be run on future architectures (i.e., it is forward compatible). However, CUDA 12.0 introduced the concept of "architecture-accelerated features" whose PTX does not have forward compatibility guarantees. Several Hopper PTX instructions fall under this category of architecture-accelerated features, and thus require a `sm_90a` target architecture (note the "a" appended). For more details on this and other architecture-accelerated instructions, please refer to the [CUDA Documentation](https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#feature-availability). 
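
To make the distinction concrete, here is a minimal, hypothetical sketch (not taken from the repository): `nvcc` defines `__CUDA_ARCH_FEAT_SM90_ALL` only when device code is compiled for the `sm_90a` target, so Hopper-only paths can be guarded on that macro in addition to `__CUDA_ARCH__`, and a build that targets plain `sm_90` falls back rather than emitting unsupported instructions. The kernel name below is illustrative only.

```c++
// Hedged sketch, not repository code: select the architecture-accelerated path
// only when this translation unit is actually compiled for sm_90a.
#include <cstdio>
#include <cuda_runtime.h>

__global__ void which_path() {
#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL)
  // Hopper architecture-accelerated features (e.g. WGMMA/TMA-based paths) may be used here.
  printf("compiled for sm_90a\n");
#else
  // Portable fallback; plain sm_90 or older targets land here.
  printf("compiled without Hopper architecture-accelerated features\n");
#endif
}

int main() {
  which_path<<<1, 1>>>();
  cudaDeviceSynchronize();
  return 0;
}
```

Selecting `90a` through the `CUTLASS_NVCC_ARCHS` CMake setting, as shown just below, is what makes the first branch active.
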
-|**GPU**|**CUDA Compute Capability**|**Minimum CUDA Toolkit**|**Minimum CUDA Toolkit Enabling Native Tensor Cores**| -|---|---|---|---| -|NVIDIA Tesla V100|7.0|9.2|10.1| -|NVIDIA TitanV|7.0|9.2|10.1| -|NVIDIA GeForce RTX 2080 TI, 2080, 2070|7.5|10.0|10.2| -|NVIDIA Tesla T4|7.5|10.0|10.2| -|NVIDIA A100|8.0|11.0|11.0| -|NVIDIA A10 |8.6|11.1|11.1| -|NVIDIA GeForce 3090|8.6|11.1|11.1| +The target architecture information is passed on to CUTLASS via the cmake flag `CUTLASS_NVCC_ARCHS`. In order to maximize performance on Hopper GH100, users are required to build CUTLASS with `90a` as the target architecture. If a user accidentally builds a kernel which uses SM90a features (e.g. Hopper Tensor Core Instructions), using the SM90 target (note the lack of "a"), with either CUDA Toolkit 12 or 11.8, the kernel is expected to fail with a runtime error. -For all GPUs, we recommend compiling with the [CUDA 11.6u2 Toolkit](https://developer.nvidia.com/cuda-toolkit) -for best performance. +``` +cmake .. -DCUTLASS_NVCC_ARCHS="90a" +``` + +Please refer to the [functionality documentation](./media/docs/functionality.md) for details on which kernels require which target architectures. # Documentation CUTLASS is described in the following documents and the accompanying [Doxygen documentation](https://nvidia.github.io/cutlass). -- [Quick Start Guide](/media/docs/quickstart.md) - build and run CUTLASS -- [Functionality](/media/docs/functionality.md) - summarizes functionality available in CUTLASS -- [Efficient GEMM in CUDA](media/docs/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA -- [GEMM API](media/docs/gemm_api.md) - describes the CUTLASS GEMM model and C++ template concepts -- [Implicit GEMM Convolution](media/docs/implicit_gemm_convolution.md) - describes 2-D and 3-D convolution in CUTLASS -- [Code Organization](media/docs/code_organization.md) - describes the organization and contents of the CUTLASS project -- [Terminology](media/docs/terminology.md) - describes terms used in the code -- [Programming Guidelines](media/docs/programming_guidelines.md) - guidelines for writing efficient modern CUDA C++ -- [Fundamental types](media/docs/fundamental_types.md) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays -- [Layouts](media/docs/layout.md) - describes layouts of matrices and tensors in memory -- [Tile Iterators](media/docs/tile_iterator_concept.md) - describes C++ concepts for iterating over tiles of matrices in memory -- [CUTLASS Profiler](media/docs/profiler.md) - command-line driven profiling application -- [CUTLASS Utilities](media/docs/utilities.md) - additional templates used to facilate rapid development - +- [Quick Start Guide](./media/docs/quickstart.md) - build and run CUTLASS +- [Functionality](./media/docs/functionality.md) - summarizes functionality available in CUTLASS +- [Efficient GEMM in CUDA](./media/docs/efficient_gemm.md) - describes how GEMM kernels may be implemented efficiently in CUDA +- [CUTLASS 3.x Design](./media/docs/cutlass_3x_design.md) - describes the CUTLASS 3.x design, its benefits, and how CuTe enables us to write much more composable components +- [GEMM API 3.x](./media/docs/gemm_api_3x.md) - describes the CUTLASS 3.x GEMM model and C++ template concepts +- [GEMM API 2.x](./media/docs/gemm_api.md) - describes the CUTLASS 2.x GEMM model and C++ template concepts +- [Implicit GEMM Convolution](./media/docs/implicit_gemm_convolution.md) - describes 2-D and 3-D convolution in CUTLASS +- [Code 
Organization](./media/docs/code_organization.md) - describes the organization and contents of the CUTLASS project
+- [Terminology](./media/docs/terminology.md) - describes terms used in the code
+- [Programming Guidelines](./media/docs/programming_guidelines.md) - guidelines for writing efficient modern CUDA C++
+- [Fundamental types](./media/docs/fundamental_types.md) - describes basic C++ classes used in CUTLASS to represent numeric quantities and arrays
+- [Layouts](./media/docs/layout.md) - describes layouts of matrices and tensors in memory
+- [Tile Iterators](./media/docs/tile_iterator_concept.md) - describes C++ concepts for iterating over tiles of matrices in memory
+- [CUTLASS Profiler](./media/docs/profiler.md) - command-line driven profiling application
+- [CUTLASS Utilities](./media/docs/utilities.md) - additional templates used to facilitate rapid development
+- [Dependent kernel launch](./media/docs/dependent_kernel_launch.md) - describes a new feature in Hopper which allows overlapping dependent
+kernels in the same stream, and how it is used in CUTLASS.
+
+# Resources
 We have also described the structure of an efficient GEMM in our talk at the [GPU Technology Conference 2018](http://on-demand.gputechconf.com/gtc/2018/presentation/s8854-cutlass-software-primitives-for-dense-linear-algebra-at-all-levels-and-scales-within-cuda.pdf).
+
+ - [CUTLASS: Software Primitives for Dense Linear Algebra at All Levels and Scales within CUDA](https://www.nvidia.com/en-us/on-demand/session/gtcsiliconvalley2018-s8854/)
+
+ - [Developing CUDA Kernels to Push Tensor Cores to the Absolute Limit on NVIDIA A100](https://www.nvidia.com/en-us/on-demand/session/gtcsj20-s21745/)
+
+ - [Accelerating Convolution with Tensor Cores in CUTLASS](https://www.nvidia.com/en-us/on-demand/session/gtcspring21-s31883/)
+
+ - [Accelerating Backward Data Gradient by Increasing Tensor Core Utilization in CUTLASS](https://www.nvidia.com/en-us/on-demand/session/gtcspring22-s41996/)
+
+ - [CUTLASS: Python API, Enhancements, and NVIDIA Hopper](https://www.nvidia.com/en-us/on-demand/session/gtcfall22-a41131/)
+
 # Building CUTLASS
 
 CUTLASS is a header-only template library and does not need to be built to be used by other
 projects. Client applications should target CUTLASS's `include/` directory in their include
 paths.
 
-CUTLASS unit tests, examples, and utilities can be build with CMake starting version 3.12.
+CUTLASS unit tests, examples, and utilities can be built with CMake.
+The minimum version of CMake is given in the [Quickstart guide](./media/docs/quickstart.md).
 Make sure the `CUDACXX` environment variable points to NVCC in the CUDA Toolkit installed
 on your system.
 
@@ -156,7 +198,8 @@ $ export CUDACXX=${CUDA_INSTALL_PATH}/bin/nvcc
 ```
 
 Create a build directory within the CUTLASS project, then run CMake. By default CUTLASS will build kernels
-for CUDA architecture versions 5.0, 6.0, 6.1, 7.0, 7.5, 8.0, and 8.6. To reduce compile time you can specify
+for CUDA architecture versions 5.0, 6.0, 6.1, 7.0, 7.5, 8.0, 8.6, 8.9, and 9.0.
+To reduce compile time you can specify
 the architectures to build CUTLASS for by changing the CMake configuration setting
 `CUTLASS_NVCC_ARCHS`.
 
@@ -191,7 +234,7 @@ CUTLASS is arranged as a header-only library along with Utilities, Tools, Exampl
 and template concepts defined in the CUTLASS project. A detailed explanation of the source code organization may be found in the
-[CUTLASS documentation](media/docs/code_organization.md), but several main components are summarized below. 
+[CUTLASS documentation](./media/docs/code_organization.md), but several main components are summarized below. ## CUTLASS Template Library @@ -204,6 +247,8 @@ include/ # client applications should target this directory conv/ # code specialized for convolution + epilogue/ # code specialized for the epilogue of gemm/convolution + gemm/ # code specialized for general matrix product computations layout/ # layout definitions for matrices, tensors, and other mathematical objects in memory @@ -211,58 +256,34 @@ include/ # client applications should target this directory platform/ # CUDA-capable Standard Library components reduction/ # bandwidth-limited reduction kernels that do not fit the "gemm" model + + thread/ # simt code that can be performed within a CUDA thread transform/ # code specialized for layout, type, and domain transformations * # core vocabulary types, containers, and basic numeric operations -``` - -### CUTLASS SDK Examples - -[CUTLASS SDK examples](/examples) apply CUTLASS templates to implement basic computations. - -``` -examples/ - 00_basic_gemm/ # launches a basic GEMM with single precision inputs and outputs - - 01_cutlass_utilities/ # demonstrates CUTLASS Utilities for allocating and initializing tensors - - 02_dump_reg_smem/ # debugging utilities for printing register and shared memory contents - - 03_visualize_layout/ # utility for visualizing all layout functions in CUTLASS - - 04_tile_iterator/ # example demonstrating an iterator over tiles in memory - - 05_batched_gemm/ # example demonstrating CUTLASS's batched strided GEMM operation - - 06_splitK_gemm/ # exmaple demonstrating CUTLASS's Split-K parallel reduction kernel - - 07_volta_tensorop_gemm/ # example demonstrating mixed precision GEMM using Volta Tensor Cores - 08_turing_tensorop_gemm/ # example demonstrating integer GEMM using Turing Tensor Cores + cute/ # CuTe Layout, layout algebra, MMA/Copy atoms, tiled MMA/Copy - 09_turing_tensorop_conv2dfprop/ # example demonstrating integer implicit GEMM convolution (forward propagation) using Turing Tensor Cores + algorithm/ # Definitions of core operations such as copy, gemm, and operations on cute::tuples - 10_planar_complex/ # example demonstrating planar complex GEMM kernels + arch/ # Bare bones PTX wrapper structs for copy and math instructions - 11_planar_complex_array/ # example demonstrating planar complex kernels with batch-specific problem sizes + atom/ # Meta-information either link to or built from arch/ operators - 12_gemm_bias_relu/ # example demonstrating GEMM fused with bias and relu + mma_atom.hpp # cute::Mma_Atom and cute::TiledMma - 13_fused_two_gemms/ # example demonstrating two GEMms fused in one kernel + copy_atom.hpp # cute::Copy_Atom and cute::TiledCopy - 22_ampere_tensorop_conv2dfprop/ # example demonstrating integer implicit GEMM convolution (forward propagation) using Ampere Tensor Cores + *sm*.hpp # Arch specific meta-information for copy and math operations - 31_basic_syrk # example demonstrating Symetric rank-K update + * # Core library types such as Shape, Stride, Layout, Tensor, and associated operations - 32_basic_trmm # +``` - 33_ampere_3xtf32_tensorop_symm # +### CUTLASS SDK Examples - 35_gemm_softmax # example demonstrating GEMM fused with Softmax in mixed precision using Ampere Tensor Cores - - 40_cutlass_py # example demonstrating CUTLASS with CUDA Python -``` +[CUTLASS SDK examples](./examples) apply CUTLASS templates to implement basic computations. 
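
As a minimal sketch of the CuTe vocabulary listed in the tree above (Layout, Tensor, Shape, Stride), the snippet below builds a small column-major layout and a non-owning tensor view over host memory. It assumes only that `include/` and the CUDA Toolkit headers are on the include path and that C++17 is enabled; it is illustrative, not one of the shipped examples.

```c++
// Minimal CuTe sketch: a Layout maps logical coordinates to linear offsets,
// and a Tensor pairs such a layout with a (non-owning) pointer to data.
#include <cstdio>
#include <vector>
#include <cute/tensor.hpp>

int main() {
  using namespace cute;

  // 4x8 column-major layout: shape (4,8), strides (1,4).
  auto layout = make_layout(make_shape(Int<4>{}, Int<8>{}),
                            make_stride(Int<1>{}, Int<4>{}));

  std::vector<float> storage(size(layout), 0.0f);   // 32 elements
  auto tensor = make_tensor(storage.data(), layout); // view over storage

  tensor(2, 3) = 1.0f;   // the layout performs the index arithmetic
  print(layout);         // e.g. prints (_4,_8):(_1,_4)
  printf("\nelement (2,3) lives at linear offset %d\n", int(layout(2, 3))); // 2 + 3*4
  return 0;
}
```
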
### Tools @@ -287,7 +308,7 @@ tools/ The `test/unit/` directory consist of unit tests implemented with Google Test that demonstrate basic usage of Core API components and complete tests of the CUTLASS GEMM computations. -Instructions for building and running the Unit tests are described in the [Quickstart guide](media/docs/quickstart.md). +Instructions for building and running the Unit tests are described in the [Quickstart guide](./media/docs/quickstart.md). # Performance Profiling @@ -301,9 +322,11 @@ $ make cutlass_profiler -j16 By default, only one tile size is instantiated for each data type, math instruction, and layout. To instantiate all, set the following environment variable when running CMake from an empty `build/` directory. -Beware, this results in *thousands* of kernels and long build times. +Beware, this results in *tens of thousands* of kernels and long build times. +This would also result in a large binary size and on some platforms linker to fail on building the library. +Therefore, it's highly recommended to generate only a subset of kernels as demonstrated in the sub-section below. ```bash -$ cmake .. -DCUTLASS_NVCC_ARCHS=75 -DCUTLASS_LIBRARY_KERNELS=all +$ cmake .. -DCUTLASS_NVCC_ARCHS=90a -DCUTLASS_LIBRARY_KERNELS=all ... $ make cutlass_profiler -j16 ``` @@ -316,7 +339,7 @@ or a subset of kernels for NVIDIA Ampere and Turing architecture: ### Building a subset Tensor Core GEMM kernels -To compile a subset of Tensor Core GEMM kernels with FP32 accumulation and FP16 input targetting NVIDIA Ampere and Turing architecture, +To compile a subset of Tensor Core GEMM kernels with FP32 accumulation and FP16 input targeting NVIDIA Ampere and Turing architecture, use the below cmake command line: ```bash $ cmake .. -DCUTLASS_NVCC_ARCHS='75;80' -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_s*gemm_f16_*_nt_align8 @@ -364,7 +387,7 @@ reference_device: Passed ### Building one CUDA Core GEMM kernel -To compile one SGEMM kernel targetting NVIDIA Ampere and Turing architecture, use the below cmake command line: +To compile one SGEMM kernel targeting NVIDIA Ampere and Turing architecture, use the below cmake command line: ```bash $ cmake .. -DCUTLASS_NVCC_ARCHS='75;80' -DCUTLASS_LIBRARY_KERNELS=cutlass_simt_sgemm_128x128_8x2_nn_align1 ... @@ -406,7 +429,7 @@ $ ./tools/profiler/cutlass_profiler --kernels=sgemm --m=3456 --n=4096 --k=4096 ### Building a subset of Tensor Core Convolution kernels To compile a subset of Tensor core convolution kernels implementing forward propagation (fprop) with FP32 accumulation -and FP16 input targetting NVIDIA Ampere and Turing architecture, use the below cmake command line: +and FP16 input targeting NVIDIA Ampere and Turing architecture, use the below cmake command line: ```bash $ cmake .. -DCUTLASS_NVCC_ARCHS='75;80' -DCUTLASS_LIBRARY_KERNELS=cutlass_tensorop_s*fprop_optimized_f16 ... @@ -454,7 +477,7 @@ reference_device: Passed ### Building one Convolution CUDA kernel To compile and run one CUDA Core convolution kernel implementing forward propagation (fprop) with F32 accumulation -and FP32 input targetting NVIDIA Ampere and Turing architecture, use the below cmake command line: +and FP32 input targeting NVIDIA Ampere and Turing architecture, use the below cmake command line: ```bash $ cmake .. -DCUTLASS_NVCC_ARCHS='75;80' -DCUTLASS_LIBRARY_KERNELS=cutlass_simt_sfprop_optimized_128x128_8x2_nhwc ... 
@@ -501,9 +524,9 @@ reference_device: Passed ## More Details on Compiling CUTLASS Kernels and CUTLASS Profiler - Please follow the links for more CMake examples on selectively compiling CUTLASS kernels: - - [GEMM CMake Examples](media/docs/quickstart.md#gemm-cmake-examples) - - [Implicit GEMM conovlution CMake Examples](media/docs/quickstart.md#convolution-cmake-examples) -- [Further details about the CUTLASS Profiler are described here.](media/docs/profiler.md) + - [GEMM CMake Examples](./media/docs/quickstart.md#gemm-cmake-examples) + - [Implicit GEMM convolution CMake Examples](./media/docs/quickstart.md#convolution-cmake-examples) +- [Further details about the CUTLASS Profiler are described here.](./media/docs/profiler.md) # About @@ -517,7 +540,7 @@ The official list of CUTLASS developers and contributors is available here: [CON # Copyright -Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. SPDX-License-Identifier: BSD-3-Clause ``` @@ -546,4 +569,3 @@ SPDX-License-Identifier: BSD-3-Clause OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. ``` - diff --git a/bin2hex.cmake b/bin2hex.cmake index b0773dd659..b34e02849f 100644 --- a/bin2hex.cmake +++ b/bin2hex.cmake @@ -1,3 +1,31 @@ +# Copyright (c) 2019 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ # A small utility function which generates a C-header from an input file function(FILE_TO_C_STRING FILENAME VARIABLE_NAME OUTPUT_STRING ZERO_TERMINATED) FILE(READ "${FILENAME}" HEX_INPUT HEX) @@ -6,7 +34,7 @@ function(FILE_TO_C_STRING FILENAME VARIABLE_NAME OUTPUT_STRING ZERO_TERMINATED) endif() string(REGEX REPLACE "(....)" "\\1\n" HEX_OUTPUT ${HEX_INPUT}) - string(REGEX REPLACE "([0-9a-f][0-9a-f])" "0x\\1," HEX_OUTPUT ${HEX_OUTPUT}) + string(REGEX REPLACE "([0-9a-f][0-9a-f])" "char(0x\\1)," HEX_OUTPUT ${HEX_OUTPUT}) set(HEX_OUTPUT "static char const ${VARIABLE_NAME}[] = {\n ${HEX_OUTPUT}\n};\n") diff --git a/cmake/CTestTestfile.config.cmake b/cmake/CTestTestfile.config.cmake deleted file mode 100644 index 0705b19c12..0000000000 --- a/cmake/CTestTestfile.config.cmake +++ /dev/null @@ -1,21 +0,0 @@ -# Generated file - -if (DEFINED ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT}) - set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT $ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT}) -else() - set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT @CUTLASS_TEST_EXECUTION_ENVIRONMENT@) -endif() - -if (NOT "@TEST_EXE_DIR@" STREQUAL "") - set(TEST_EXE_PATH @TEST_EXE_DIR@/@TEST_EXE@) -else() - set(TEST_EXE_PATH @TEST_EXE@) -endif() - -add_test("@TEST_NAME@" ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@) - -if (NOT "@TEST_EXE_WORKING_DIRECTORY@" STREQUAL "") - set_tests_properties("@TEST_NAME@" PROPERTIES WORKING_DIRECTORY "@TEST_EXE_WORKING_DIRECTORY@") -endif() - -set_tests_properties(@TEST_NAME@ PROPERTIES DISABLED @__DISABLE_TESTS@) diff --git a/cmake/CTestTestfile.configure.cmake b/cmake/CTestTestfile.configure.cmake new file mode 100644 index 0000000000..611b3d181f --- /dev/null +++ b/cmake/CTestTestfile.configure.cmake @@ -0,0 +1,52 @@ +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ +# Generated file + +set(TEST_SETS_SUPPORTED @TEST_SETS_SUPPORTED@) + +if (NOT DEFINED ENV{CUTLASS_TEST_SETS}) + set(ENV{CUTLASS_TEST_SETS} @CUTLASS_DEFAULT_ACTIVE_TEST_SETS@) +endif() + +foreach(TEST_SET_REQUESTED IN ITEMS $ENV{CUTLASS_TEST_SETS}) + if (NOT TEST_SET_REQUESTED IN_LIST TEST_SETS_SUPPORTED) + message(STATUS "Skipping tests for @TEST_EXE_PATH@ as ${TEST_SET_REQUESTED} is not in the set of [${TEST_SETS_SUPPORTED}].") + return() + endif() +endforeach() + +set(TEST_EXE_PATH @TEST_EXE_PATH@) +set(TEST_EXE_WORKING_DIRECTORY @TEST_EXE_WORKING_DIRECTORY@) +set(CUTLASS_USE_EXTENDED_ADD_TEST_FORMAT @TEST_USE_EXTENDED_FORMAT@) + +if (DEFINED ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT}) + set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT $ENV{CUTLASS_TEST_EXECUTION_ENVIRONMENT}) +else() + set(_CUTLASS_TEST_EXECUTION_ENVIRONMENT @CUTLASS_TEST_EXECUTION_ENVIRONMENT@) +endif() diff --git a/cmake/CTestTestfile.test.configure.cmake b/cmake/CTestTestfile.test.configure.cmake new file mode 100644 index 0000000000..31dba54498 --- /dev/null +++ b/cmake/CTestTestfile.test.configure.cmake @@ -0,0 +1,43 @@ +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +if (CUTLASS_USE_EXTENDED_ADD_TEST_FORMAT) + # The longform/extended format allows generator expressions to be + # expanded property and is useful in contexts where the files need + # to be immediately included into being-processed cmake code. 
+ add_test(NAME @TESTCASE_NAME@ COMMAND ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@) +else() + add_test(@TESTCASE_NAME@ ${_CUTLASS_TEST_EXECUTION_ENVIRONMENT} "${TEST_EXE_PATH}" @TEST_COMMAND_OPTIONS@) +endif() + +if (TEST_EXE_WORKING_DIRECTORY) + set_tests_properties(@TESTCASE_NAME@ PROPERTIES WORKING_DIRECTORY "${TEST_EXE_WORKING_DIRECTORY}") +endif() + +set_tests_properties(@TESTCASE_NAME@ PROPERTIES DISABLED @__DISABLE_TESTS@) + diff --git a/cmake/NvidiaCutlassConfig.cmake b/cmake/NvidiaCutlassConfig.cmake.in similarity index 52% rename from cmake/NvidiaCutlassConfig.cmake rename to cmake/NvidiaCutlassConfig.cmake.in index 701ecb4af4..2fe69119a1 100644 --- a/cmake/NvidiaCutlassConfig.cmake +++ b/cmake/NvidiaCutlassConfig.cmake.in @@ -2,6 +2,8 @@ get_filename_component(NvidiaCutlass_CMAKE_DIR "${CMAKE_CURRENT_LIST_FILE}" PATH include(CMakeFindDependencyMacro) -if(NOT TARGET nvidia::cutlass::CUTLASS) - include("${NvidiaCutlass_CMAKE_DIR}/NvidiaCutlassTargets.cmake") +if(TARGET nvidia::cutlass::CUTLASS) + return() endif() + +include("${NvidiaCutlass_CMAKE_DIR}/NvidiaCutlassTargets.cmake") diff --git a/cmake/NvidiaCutlassPackageConfig.cmake b/cmake/NvidiaCutlassPackageConfig.cmake index bb15b1bb70..364fba7a20 100644 --- a/cmake/NvidiaCutlassPackageConfig.cmake +++ b/cmake/NvidiaCutlassPackageConfig.cmake @@ -1,3 +1,31 @@ +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + set(CPACK_PACKAGE_NAME NvidiaCutlass) set(CPACK_PACKAGE_VENDOR NVIDIA) set(CPACK_PACKAGE_CONTACT info@nvidia.com) diff --git a/cmake/googletest.cmake b/cmake/googletest.cmake index 85edc807c9..d220cfadc2 100644 --- a/cmake/googletest.cmake +++ b/cmake/googletest.cmake @@ -1,3 +1,31 @@ +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + include(FetchContent) set(GOOGLETEST_DIR "" CACHE STRING "Location of local GoogleTest repo to build against") @@ -6,10 +34,11 @@ if(GOOGLETEST_DIR) set(FETCHCONTENT_SOURCE_DIR_GOOGLETEST ${GOOGLETEST_DIR} CACHE STRING "GoogleTest source directory override") endif() +set(GTEST_REPOSITORY "https://github.com/google/googletest.git" CACHE STRING "GoogleTest repo to fetch") FetchContent_Declare( googletest - GIT_REPOSITORY https://github.com/google/googletest.git - GIT_TAG 0fe9660 + GIT_REPOSITORY ${GTEST_REPOSITORY} + GIT_TAG v1.14.0 ) FetchContent_GetProperties(googletest) diff --git a/cmake/nop.cu b/cmake/nop.cu index f477557225..be2b15881e 100644 --- a/cmake/nop.cu +++ b/cmake/nop.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/cmake/version.h.in b/cmake/version.h.in deleted file mode 100644 index 1b48e1abc2..0000000000 --- a/cmake/version.h.in +++ /dev/null @@ -1,38 +0,0 @@ -#include -#include - -#define CUTLASS_MAJOR @CUTLASS_VERSION_MAJOR@ -#define CUTLASS_MINOR @CUTLASS_VERSION_MINOR@ -#define CUTLASS_PATCH @CUTLASS_VERSION_PATCH@ -#define CUTLASS_BUILD @CUTLASS_VERSION_BUILD@ -#define CUTLASS_VERSION ((CUTLASS_MAJOR)*100 + (CUTLASS_MINOR)*10 + CUTLASS_PATCH) - -namespace cutlass { - - inline uint32_t getVersion() { - return CUTLASS_VERSION; - } - inline uint32_t getVersionMajor() { - return CUTLASS_MAJOR; - } - inline uint32_t getVersionMinor() { - return CUTLASS_MINOR; - } - inline uint32_t getVersionPatch() { - return CUTLASS_PATCH; - } - inline uint32_t getVersionBuild() { - return CUTLASS_BUILD + 0; - } - inline std::string getVersionString() { - std::string version = "@CUTLASS_VERSION@"; - if (getVersionBuild()) { - version += "." + std::to_string(getVersionBuild()); - } - return version; - } - inline std::string getGitRevision() { - return "@CUTLASS_REVISION@"; - } - -} // namespace cutlass diff --git a/cmake/version_extended.h.in b/cmake/version_extended.h.in new file mode 100644 index 0000000000..3613063022 --- /dev/null +++ b/cmake/version_extended.h.in @@ -0,0 +1,34 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#define CUTLASS_BUILD @CUTLASS_VERSION_BUILD@ +#define CUTLASS_REVISION "@CUTLASS_REVISION@" diff --git a/cuBLAS.cmake b/cuBLAS.cmake index a66274eb78..383871fdff 100644 --- a/cuBLAS.cmake +++ b/cuBLAS.cmake @@ -1,4 +1,4 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -40,7 +40,7 @@ elseif(NOT TARGET cublas) find_path( _CUBLAS_INCLUDE_DIR - NAMES cublas.h + NAMES cublas_v2.h HINTS ${CUBLAS_INCLUDE_PATH} ENV CUBLAS_INCLUDE_PATH diff --git a/cuDNN.cmake b/cuDNN.cmake index 4f89f43dd2..0b37ff7c30 100644 --- a/cuDNN.cmake +++ b/cuDNN.cmake @@ -1,4 +1,4 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without diff --git a/docs/annotated.html b/docs/annotated.html index d587a26d9a..233691c235 100644 --- a/docs/annotated.html +++ b/docs/annotated.html @@ -280,15 +280,15 @@  CDefaultGemmConfiguration< arch::OpClassWmmaTensorOp, ArchTag, ElementA, ElementB, ElementC, ElementAccumulator >  CGemm  CArgumentsArgument structure - CGemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >Parital specialization for column-major output exchanges problem size and operand + CGemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero >Partial specialization for column-major output exchanges problem size and operand  CArgumentsArgument structure  CGemmBatched  CArgumentsArgument structure - CGemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >Parital specialization for column-major output exchanges problem size and operand + CGemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >Partial specialization for column-major output exchanges problem size and operand  CArgumentsArgument structure  CGemmComplex  CArgumentsArgument structure - CGemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >Parital specialization for column-major output exchanges problem size and operand + CGemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, 
ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial >Partial specialization for column-major output exchanges problem size and operand  CArgumentsArgument structure  CGemmSplitKParallel  CArgumentsArgument structure @@ -594,7 +594,7 @@  CGemm  CGemm< ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType, AccumulatorType, arch::OpMultiplyAdd >Partial specialization for multiply-add  CGemm< ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType, AccumulatorType, arch::OpMultiplyAddSaturate >Partial specialization for multiply-add-saturate - CGemm< ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType, AccumulatorType, arch::OpXorPopc >Parital specialization for XOR-popc + CGemm< ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType, AccumulatorType, arch::OpXorPopc >Partial specialization for XOR-popc  CTensorDiagonalForEachLaunches a kernel calling a functor for each element along a tensor's diagonal  CTensorForEachLaunches a kernel calling a functor for each element in a tensor's index space  Nhost @@ -620,7 +620,7 @@  CGemm  CGemm< ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType, ComputeType, arch::OpMultiplyAdd >Partial specialization for multiply-add  CGemm< ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType, ComputeType, arch::OpMultiplyAddSaturate >Partial specialization for multiply-add-saturate - CGemm< ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType, ComputeType, arch::OpXorPopc >Parital specialization for XOR-popc + CGemm< ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType, ComputeType, arch::OpXorPopc >Partial specialization for XOR-popc  Nthread  CMatrixPer-thread matrix object storing a packed matrix  Ntransform @@ -718,7 +718,7 @@  CDetailInternal details made public to facilitate introspection Iterations along each dimension (concept: PitchLinearShape)  CTransposePitchLinearThreadMap  CDetailInternal details made public to facilitate introspection Iterations along each dimension (concept: PitchLinearShape) - CTransposePitchLinearThreadMap2DThreadTileThread Mapping a 2D threadtiled mapping as a tranposed Pitchlinear2DThreadTile mapping + CTransposePitchLinearThreadMap2DThreadTileThread Mapping a 2D threadtiled mapping as a transposed Pitchlinear2DThreadTile mapping  CTransposePitchLinearThreadMapSimt  CAlignedArrayAligned array type  CAlignedBufferModifies semantics of cutlass::Array<> to provide guaranteed alignment diff --git a/docs/classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA___00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html b/docs/classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA___00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html index d53d8d67dc..6800f4fe70 100644 --- a/docs/classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA___00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html +++ b/docs/classcutlass_1_1gemm_1_1device_1_1GemmBatched_3_01ElementA___00_01LayoutA___00_01ElementB___00_0c9bb6f4463ab6085e6008b5d5ad6abfd.html @@ -108,7 +108,7 @@
-Parital specialization for column-major output exchanges problem size and operand.
+Partial specialization for column-major output exchanges problem size and operand.

#include <gemm_batched.h>

diff --git a/docs/classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA___00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html b/docs/classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA___00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html index 8468585615..d09783439c 100644 --- a/docs/classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA___00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html +++ b/docs/classcutlass_1_1gemm_1_1device_1_1GemmComplex_3_01ElementA___00_01LayoutA___00_01ElementB___00_07c56401b4df75709ae636675d9980a9a.html @@ -108,7 +108,7 @@
-Parital specialization for column-major output exchanges problem size and operand.
+Partial specialization for column-major output exchanges problem size and operand.

#include <gemm_complex.h>

diff --git a/docs/classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA___00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html b/docs/classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA___00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html index f34be6e5ee..323ecfc2b0 100644 --- a/docs/classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA___00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html +++ b/docs/classcutlass_1_1gemm_1_1device_1_1Gemm_3_01ElementA___00_01LayoutA___00_01ElementB___00_01Layout4d0960ae6b1d1bf19e6239dbd002249c.html @@ -108,7 +108,7 @@
-Parital specialization for column-major output exchanges problem size and operand.
+Partial specialization for column-major output exchanges problem size and operand.

#include <gemm.h>
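
The wording "exchanges problem size and operand" above reflects a standard identity rather than anything CUTLASS-specific: the buffer of a column-major matrix is the row-major buffer of its transpose, and C = A * B implies C^T = B^T * A^T, so a kernel that only produces row-major output can serve a column-major-output problem by swapping the A/B operands and the M/N extents. The plain C++ reference sketch below (hypothetical helper names, not CUTLASS code) checks that equivalence numerically.

```c++
// Hedged sketch of the operand-exchange idea behind these specializations.
#include <cassert>
#include <vector>

// Column-major reference: C(m,n) = sum_k A(m,k)*B(k,n), with X(i,j) = X[i + j*rows].
void gemm_colmajor(int M, int N, int K, const float* A, const float* B, float* C) {
  for (int n = 0; n < N; ++n)
    for (int m = 0; m < M; ++m) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) acc += A[m + k * M] * B[k + n * K];
      C[m + n * M] = acc;
    }
}

// Row-major reference: C(m,n) = sum_k A(m,k)*B(k,n), with X(i,j) = X[i*cols + j].
void gemm_rowmajor(int M, int N, int K, const float* A, const float* B, float* C) {
  for (int m = 0; m < M; ++m)
    for (int n = 0; n < N; ++n) {
      float acc = 0.f;
      for (int k = 0; k < K; ++k) acc += A[m * K + k] * B[k * N + n];
      C[m * N + n] = acc;
    }
}

int main() {
  int M = 2, N = 3, K = 4;
  std::vector<float> A(M * K), B(K * N), C_col(M * N), C_exch(M * N);
  for (int i = 0; i < M * K; ++i) A[i] = float(i + 1);
  for (int i = 0; i < K * N; ++i) B[i] = float(2 * i) - 3.f;

  // Column-major product computed directly.
  gemm_colmajor(M, N, K, A.data(), B.data(), C_col.data());

  // Same product via operand exchange: a row-major-output GEMM with (M,N) swapped
  // and (A,B) swapped writes a buffer that is exactly the column-major C.
  gemm_rowmajor(N, M, K, B.data(), A.data(), C_exch.data());

  for (int i = 0; i < M * N; ++i) assert(C_col[i] == C_exch[i]);
  return 0;
}
```
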

diff --git a/docs/command__line_8h_source.html b/docs/command__line_8h_source.html index f98c9f1eaf..325a303405 100644 --- a/docs/command__line_8h_source.html +++ b/docs/command__line_8h_source.html @@ -98,7 +98,7 @@
command_line.h
-Go to the documentation of this file.
1 /******************************************************************************
2  * Copyright (c) 2011-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are not permitted.
6  *
7  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
8  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
9  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
10  * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
11  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
12  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
13  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
14  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
15  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
16  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
17  *
18  ******************************************************************************/
19 
20 #pragma once
21 
27 #include <iostream>
28 #include <limits>
29 #include <sstream>
30 #include <string>
31 #include <vector>
32 
33 #include <cuda_runtime.h>
34 
35 namespace cutlass {
36 
37 /******************************************************************************
38  * command_line
39  ******************************************************************************/
40 
44 struct CommandLine {
45  std::vector<std::string> keys;
46  std::vector<std::string> values;
47  std::vector<std::string> args;
48 
52  CommandLine(int argc, const char** argv) {
53  using namespace std;
54 
55  for (int i = 1; i < argc; i++) {
56  string arg = argv[i];
57 
58  if ((arg[0] != '-') || (arg[1] != '-')) {
59  args.push_back(arg);
60  continue;
61  }
62 
63  string::size_type pos;
64  string key, val;
65  if ((pos = arg.find('=')) == string::npos) {
66  key = string(arg, 2, arg.length() - 2);
67  val = "";
68  } else {
69  key = string(arg, 2, pos - 2);
70  val = string(arg, pos + 1, arg.length() - 1);
71  }
72 
73  keys.push_back(key);
74  values.push_back(val);
75  }
76  }
77 
81  bool check_cmd_line_flag(const char* arg_name) const {
82  using namespace std;
83 
84  for (int i = 0; i < int(keys.size()); ++i) {
85  if (keys[i] == string(arg_name)) return true;
86  }
87  return false;
88  }
89 
93  template <typename value_t>
94  int num_naked_args() const {
95  return args.size();
96  }
97 
101  template <typename value_t>
102  void get_cmd_line_argument(int index, value_t& val) const {
103  using namespace std;
104  if (index < args.size()) {
105  istringstream str_stream(args[index]);
106  str_stream >> val;
107  }
108  }
109 
113  void get_cmd_line_argument(const char* arg_name, bool& val, bool _default = true) const {
114  val = _default;
115  if (check_cmd_line_flag(arg_name)) {
116  std::string value;
117  get_cmd_line_argument(arg_name, value);
118 
119  val = !(value == "0" || value == "false");
120  }
121  }
122 
126  template <typename value_t>
127  void get_cmd_line_argument(const char* arg_name,
128  value_t& val,
129  value_t const& _default = value_t()) const {
130  using namespace std;
131 
132  val = _default;
133 
134  for (int i = 0; i < int(keys.size()); ++i) {
135  if (keys[i] == string(arg_name)) {
136  istringstream str_stream(values[i]);
137  str_stream >> val;
138  }
139  }
140  }
141 
145  template <typename value_t>
146  void get_cmd_line_arguments(const char* arg_name,
147  std::vector<value_t>& vals,
148  char sep = ',') const {
149  using namespace std;
150 
151  if (check_cmd_line_flag(arg_name)) {
152  // Clear any default values
153  vals.clear();
154 
155  // Recover from multi-value string
156  for (int i = 0; i < keys.size(); ++i) {
157  if (keys[i] == string(arg_name)) {
158  string val_string(values[i]);
159  seperate_string(val_string, vals, sep);
160  }
161  }
162  }
163  }
164 
169  void get_cmd_line_argument_pairs(const char* arg_name,
170  std::vector<std::pair<std::string, std::string> >& tokens,
171  char delim = ',',
172  char sep = ':') const {
173  if (check_cmd_line_flag(arg_name)) {
174  std::string value;
175  get_cmd_line_argument(arg_name, value);
176 
177  tokenize(tokens, value, delim, sep);
178  }
179  }
180 
185  void get_cmd_line_argument_ranges(const char* arg_name,
186  std::vector<std::vector<std::string> >& vals,
187  char delim = ',',
188  char sep = ':') const {
189  std::vector<std::string> ranges;
190  get_cmd_line_arguments(arg_name, ranges, delim);
191 
192  for (std::vector<std::string>::const_iterator range = ranges.begin();
193  range != ranges.end(); ++range) {
194 
195  std::vector<std::string> range_vals;
196  seperate_string(*range, range_vals, sep);
197  vals.push_back(range_vals);
198  }
199  }
200 
204  int parsed_argc() const { return (int)keys.size(); }
205 
206  //-------------------------------------------------------------------------
207  // Utility functions
208  //-------------------------------------------------------------------------
209 
211  static void tokenize(std::vector<std::pair<std::string, std::string> >& tokens,
212  std::string const& str,
213  char delim = ',',
214  char sep = ':') {
215  // Home-built to avoid Boost dependency
216  size_t s_idx = 0;
217  size_t d_idx = std::string::npos;
218  while (s_idx < str.size()) {
219  d_idx = str.find_first_of(delim, s_idx);
220 
221  size_t end_idx = (d_idx != std::string::npos ? d_idx : str.size());
222  size_t sep_idx = str.find_first_of(sep, s_idx);
223  size_t offset = 1;
224  if (sep_idx == std::string::npos || sep_idx >= end_idx) {
225  sep_idx = end_idx;
226  offset = 0;
227  }
228 
229  std::pair<std::string, std::string> item(
230  str.substr(s_idx, sep_idx - s_idx),
231  str.substr(sep_idx + offset, end_idx - sep_idx - offset));
232 
233  tokens.push_back(item);
234  s_idx = end_idx + 1;
235  }
236  }
237 
239  static void tokenize(std::vector<std::string>& tokens,
240  std::string const& str,
241  char delim = ',',
242  char sep = ':') {
243  typedef std::vector<std::pair<std::string, std::string> > TokenVector;
244  typedef TokenVector::const_iterator token_iterator;
245 
246  std::vector<std::pair<std::string, std::string> > token_pairs;
247  tokenize(token_pairs, str, delim, sep);
248  for (token_iterator tok = token_pairs.begin(); tok != token_pairs.end(); ++tok) {
249  tokens.push_back(tok->first);
250  }
251  }
252 
253  template <typename value_t>
254  static void seperate_string(std::string const& str,
255  std::vector<value_t>& vals,
256  char sep = ',') {
257  std::istringstream str_stream(str);
258  std::string::size_type old_pos = 0;
259  std::string::size_type new_pos = 0;
260 
261  // Iterate <sep>-delimited values
262  value_t val;
263  while ((new_pos = str.find(sep, old_pos)) != std::string::npos) {
264  if (new_pos != old_pos) {
265  str_stream.width(new_pos - old_pos);
266  str_stream >> val;
267  vals.push_back(val);
268  }
269 
270  // skip over delimiter
271  str_stream.ignore(1);
272  old_pos = new_pos + 1;
273  }
274 
275  // Read last value
276  str_stream >> val;
277  vals.push_back(val);
278  }
279 };
280 
281 } // namespace cutlass
Definition: aligned_buffer.h:35
+Go to the documentation of this file.
1 /******************************************************************************
2  * Copyright (c) 2011-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without
5  * modification, are not permitted.
6  *
7  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
8  * ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
9  * WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
10  * DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE FOR ANY
11  * DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES
12  * (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES;
13  * LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND
14  * ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT
15  * (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS
16  * SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
17  *
18  ******************************************************************************/
19 
20 #pragma once
21 
27 #include <iostream>
28 #include <limits>
29 #include <sstream>
30 #include <string>
31 #include <vector>
32 
33 #include <cuda_runtime.h>
34 
35 namespace cutlass {
36 
37 /******************************************************************************
38  * command_line
39  ******************************************************************************/
40 
44 struct CommandLine {
45  std::vector<std::string> keys;
46  std::vector<std::string> values;
47  std::vector<std::string> args;
48 
52  CommandLine(int argc, const char** argv) {
53  using namespace std;
54 
55  for (int i = 1; i < argc; i++) {
56  string arg = argv[i];
57 
58  if ((arg[0] != '-') || (arg[1] != '-')) {
59  args.push_back(arg);
60  continue;
61  }
62 
63  string::size_type pos;
64  string key, val;
65  if ((pos = arg.find('=')) == string::npos) {
66  key = string(arg, 2, arg.length() - 2);
67  val = "";
68  } else {
69  key = string(arg, 2, pos - 2);
70  val = string(arg, pos + 1, arg.length() - 1);
71  }
72 
73  keys.push_back(key);
74  values.push_back(val);
75  }
76  }
77 
81  bool check_cmd_line_flag(const char* arg_name) const {
82  using namespace std;
83 
84  for (int i = 0; i < int(keys.size()); ++i) {
85  if (keys[i] == string(arg_name)) return true;
86  }
87  return false;
88  }
89 
93  template <typename value_t>
94  int num_naked_args() const {
95  return args.size();
96  }
97 
101  template <typename value_t>
102  void get_cmd_line_argument(int index, value_t& val) const {
103  using namespace std;
104  if (index < args.size()) {
105  istringstream str_stream(args[index]);
106  str_stream >> val;
107  }
108  }
109 
113  void get_cmd_line_argument(const char* arg_name, bool& val, bool _default = true) const {
114  val = _default;
115  if (check_cmd_line_flag(arg_name)) {
116  std::string value;
117  get_cmd_line_argument(arg_name, value);
118 
119  val = !(value == "0" || value == "false");
120  }
121  }
122 
126  template <typename value_t>
127  void get_cmd_line_argument(const char* arg_name,
128  value_t& val,
129  value_t const& _default = value_t()) const {
130  using namespace std;
131 
132  val = _default;
133 
134  for (int i = 0; i < int(keys.size()); ++i) {
135  if (keys[i] == string(arg_name)) {
136  istringstream str_stream(values[i]);
137  str_stream >> val;
138  }
139  }
140  }
141 
145  template <typename value_t>
146  void get_cmd_line_arguments(const char* arg_name,
147  std::vector<value_t>& vals,
148  char sep = ',') const {
149  using namespace std;
150 
151  if (check_cmd_line_flag(arg_name)) {
152  // Clear any default values
153  vals.clear();
154 
155  // Recover from multi-value string
156  for (int i = 0; i < keys.size(); ++i) {
157  if (keys[i] == string(arg_name)) {
158  string val_string(values[i]);
159  separate_string(val_string, vals, sep);
160  }
161  }
162  }
163  }
164 
169  void get_cmd_line_argument_pairs(const char* arg_name,
170  std::vector<std::pair<std::string, std::string> >& tokens,
171  char delim = ',',
172  char sep = ':') const {
173  if (check_cmd_line_flag(arg_name)) {
174  std::string value;
175  get_cmd_line_argument(arg_name, value);
176 
177  tokenize(tokens, value, delim, sep);
178  }
179  }
180 
185  void get_cmd_line_argument_ranges(const char* arg_name,
186  std::vector<std::vector<std::string> >& vals,
187  char delim = ',',
188  char sep = ':') const {
189  std::vector<std::string> ranges;
190  get_cmd_line_arguments(arg_name, ranges, delim);
191 
192  for (std::vector<std::string>::const_iterator range = ranges.begin();
193  range != ranges.end(); ++range) {
194 
195  std::vector<std::string> range_vals;
196  separate_string(*range, range_vals, sep);
197  vals.push_back(range_vals);
198  }
199  }
200 
204  int parsed_argc() const { return (int)keys.size(); }
205 
206  //-------------------------------------------------------------------------
207  // Utility functions
208  //-------------------------------------------------------------------------
209 
211  static void tokenize(std::vector<std::pair<std::string, std::string> >& tokens,
212  std::string const& str,
213  char delim = ',',
214  char sep = ':') {
215  // Home-built to avoid Boost dependency
216  size_t s_idx = 0;
217  size_t d_idx = std::string::npos;
218  while (s_idx < str.size()) {
219  d_idx = str.find_first_of(delim, s_idx);
220 
221  size_t end_idx = (d_idx != std::string::npos ? d_idx : str.size());
222  size_t sep_idx = str.find_first_of(sep, s_idx);
223  size_t offset = 1;
224  if (sep_idx == std::string::npos || sep_idx >= end_idx) {
225  sep_idx = end_idx;
226  offset = 0;
227  }
228 
229  std::pair<std::string, std::string> item(
230  str.substr(s_idx, sep_idx - s_idx),
231  str.substr(sep_idx + offset, end_idx - sep_idx - offset));
232 
233  tokens.push_back(item);
234  s_idx = end_idx + 1;
235  }
236  }
237 
239  static void tokenize(std::vector<std::string>& tokens,
240  std::string const& str,
241  char delim = ',',
242  char sep = ':') {
243  typedef std::vector<std::pair<std::string, std::string> > TokenVector;
244  typedef TokenVector::const_iterator token_iterator;
245 
246  std::vector<std::pair<std::string, std::string> > token_pairs;
247  tokenize(token_pairs, str, delim, sep);
248  for (token_iterator tok = token_pairs.begin(); tok != token_pairs.end(); ++tok) {
249  tokens.push_back(tok->first);
250  }
251  }
252 
253  template <typename value_t>
254  static void separate_string(std::string const& str,
255  std::vector<value_t>& vals,
256  char sep = ',') {
257  std::istringstream str_stream(str);
258  std::string::size_type old_pos = 0;
259  std::string::size_type new_pos = 0;
260 
261  // Iterate <sep>-delimited values
262  value_t val;
263  while ((new_pos = str.find(sep, old_pos)) != std::string::npos) {
264  if (new_pos != old_pos) {
265  str_stream.width(new_pos - old_pos);
266  str_stream >> val;
267  vals.push_back(val);
268  }
269 
270  // skip over delimiter
271  str_stream.ignore(1);
272  old_pos = new_pos + 1;
273  }
274 
275  // Read last value
276  str_stream >> val;
277  vals.push_back(val);
278  }
279 };
280 
281 } // namespace cutlass
Definition: aligned_buffer.h:35
void get_cmd_line_argument(const char *arg_name, value_t &val, value_t const &_default=value_t()) const
Definition: command_line.h:127
void get_cmd_line_argument_pairs(const char *arg_name, std::vector< std::pair< std::string, std::string > > &tokens, char delim= ',', char sep= ':') const
Definition: command_line.h:169
STL namespace.
@@ -116,7 +116,7 @@
CommandLine(int argc, const char **argv)
Definition: command_line.h:52
std::vector< std::string > args
Definition: command_line.h:47
Definition: command_line.h:44
-
static void seperate_string(std::string const &str, std::vector< value_t > &vals, char sep= ',')
Definition: command_line.h:254
+
static void separate_string(std::string const &str, std::vector< value_t > &vals, char sep= ',')
Definition: command_line.h:254
int parsed_argc() const
Definition: command_line.h:204
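For readers following the `seperate_string` → `separate_string` rename above, here is a minimal usage sketch of the `cutlass::CommandLine` helper whose source is shown in this listing. The option names (`--m`, `--verbose`, `--tiles`) are invented for illustration, and the include path assumes CUTLASS's `tools/util/include` directory is on the compiler's include path.

```cpp
#include <iostream>
#include <string>
#include <vector>

#include "cutlass/util/command_line.h"  // the command_line.h shown in this diff

int main(int argc, const char **argv) {
  cutlass::CommandLine cmd(argc, argv);

  // Scalar option with a default value: --m=1024
  int m = 0;
  cmd.get_cmd_line_argument("m", m, 1024);

  // Boolean flag: --verbose or --verbose=false; the default is used when absent
  bool verbose = false;
  cmd.get_cmd_line_argument("verbose", verbose, false);

  // Comma-separated list: --tiles=128,256,64 (split internally by separate_string)
  std::vector<int> tiles;
  cmd.get_cmd_line_arguments("tiles", tiles);

  std::cout << "m=" << m << ", verbose=" << verbose
            << ", tiles parsed=" << tiles.size() << std::endl;
  return 0;
}
```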
diff --git a/docs/default__mma__core__simt_8h_source.html b/docs/default__mma__core__simt_8h_source.html index 6897c764f0..6e068e5a87 100644 --- a/docs/default__mma__core__simt_8h_source.html +++ b/docs/default__mma__core__simt_8h_source.html @@ -144,7 +144,7 @@
Defines the size of an element in bits.
Definition: numeric_types.h:42
-
Thread Mapping a 2D threadtiled mapping as a tranposed Pitchlinear2DThreadTile mapping.
Definition: pitch_linear_thread_map.h:713
+
Thread Mapping a 2D threadtiled mapping as a transposed Pitchlinear2DThreadTile mapping.
Definition: pitch_linear_thread_map.h:713
Defines basic properties needed by CTA-level GEMMs assuming expectations about data layout of the glo...
diff --git a/docs/device_2gemm__batched_8h.html b/docs/device_2gemm__batched_8h.html index e648c67eea..3ef58bf18e 100644 --- a/docs/device_2gemm__batched_8h.html +++ b/docs/device_2gemm__batched_8h.html @@ -130,7 +130,7 @@  Argument structure. More...
  class  cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ > - Parital specialization for column-major output exchanges problem size and operand. More...
+ Partial specialization for column-major output exchanges problem size and operand. More...
  struct  cutlass::gemm::device::GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ >::Arguments  Argument structure. More...
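A brief reminder of what "exchanges problem size and operand" means in these Doxygen briefs (the identity below is standard linear algebra, not text taken from the diff): a column-major output can be produced by a kernel that writes row-major output, because

$$
D^{T} = (\alpha A B + \beta C)^{T} = \alpha\, B^{T} A^{T} + \beta\, C^{T},
$$

so the partial specialization swaps the A and B operands, transposes their layouts, and exchanges M and N in the problem size before delegating to the row-major path.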
diff --git a/docs/device_2kernel_2tensor__foreach_8h_source.html b/docs/device_2kernel_2tensor__foreach_8h_source.html index a4839c25f7..21c8fb591f 100644 --- a/docs/device_2kernel_2tensor__foreach_8h_source.html +++ b/docs/device_2kernel_2tensor__foreach_8h_source.html @@ -100,7 +100,7 @@
Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
25 
26 #pragma once
27 
28 #include "cutlass/cutlass.h"
29 #include "cutlass/coord.h"
30 
31 namespace cutlass {
32 namespace reference {
33 namespace device {
34 namespace kernel {
35 
37 
39 namespace detail {
40 
42 template <typename Func, int Rank, int RankRemaining>
44 
46  __inline__ __device__
47  TensorForEachHelper(Func &func, Coord<Rank> const &size, Coord<Rank> &coord, int64_t index) {
48 
49  int64_t product = 1;
50 
52  for (int i = Rank - RankRemaining; i < Rank; ++i) {
53  product *= size[i];
54  }
55 
56  coord[Rank - 1 - RankRemaining] = index / product;
57  int64_t remaining = index % product;
58 
59  TensorForEachHelper<Func, Rank, RankRemaining-1>(func, size, coord, remaining);
60  }
61 };
62 
64 template <typename Func, int Rank>
65 struct TensorForEachHelper<Func, Rank, 0> {
66 
68  __inline__ __device__
69  TensorForEachHelper(Func &func, Coord<Rank> const &size, Coord<Rank> &coord, int64_t index) {
70 
71  coord[Rank - 1] = index;
72 
73  if (coord < size) {
74  func(coord);
75  }
76  }
77 };
78 
79 } // namespace detail
80 
82 
84 template <typename Func, int Rank, typename Params>
85 __global__ void TensorForEach(Coord<Rank> size, Params params = Params()) {
86 
87  Func func(params);
88 
89  int64_t index = threadIdx.x + blockIdx.x * blockDim.x;
90  int64_t max_index = 1;
91 
93  for (int i = 0; i < Rank; ++i) {
94  max_index *= size[i];
95  }
96 
98  while (index < max_index) {
99  Coord<Rank> coord;
100 
101  detail::TensorForEachHelper<Func, Rank, Rank - 1>(func, size, coord, index);
102  index += blockDim.x * gridDim.x;
103  }
104 }
105 
107 
109 template <typename Func, int Rank, typename Params>
110 __global__ void TensorDiagonalForEach(Coord<Rank> size, Params params, int start, int end) {
111 
112  Func func(params);
113 
114  int64_t index = threadIdx.x + blockIdx.x * blockDim.x + start;
115 
116  if (index < end) {
117  Coord<Rank> coord;
118 
120  for (int i = 0; i < Rank; ++i) {
121  coord[i] = index;
122  }
123 
124  func(coord);
125  }
126 }
127 
129 
130 template <typename Element, typename Func>
131 __global__ void BlockForEach(
132  Element *ptr,
133  size_t capacity,
134  typename Func::Params params) {
135 
136  Func func(params);
137 
138  size_t index = threadIdx.x + blockIdx.x * blockDim.x;
139 
140  for (; index < capacity; index += blockDim.x * gridDim.x) {
141  ptr[index] = func();
142  }
143 }
144 
146 
147 } // namespace kernel
148 } // namespace device
149 } // namespace reference
150 } // namespace cutlass
151 
Definition: aligned_buffer.h:35
A Coord is a coordinate of arbitrary rank into a tensor or matrix.
-
__inline__ __device__ TensorForEachHelper(Func &func, Coord< Rank > const &size, Coord< Rank > &coord, int64_t index)
Constructor for fastest chaning rank.
Definition: device/kernel/tensor_foreach.h:69
+
__inline__ __device__ TensorForEachHelper(Func &func, Coord< Rank > const &size, Coord< Rank > &coord, int64_t index)
Constructor for fastest changing rank.
Definition: device/kernel/tensor_foreach.h:69
#define CUTLASS_PRAGMA_UNROLL
Definition: cutlass.h:110
__global__ void BlockForEach(Element *ptr, size_t capacity, typename Func::Params params)
Definition: device/kernel/tensor_foreach.h:131
#define CUTLASS_PRAGMA_NO_UNROLL
Definition: cutlass.h:111
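The `TensorForEachHelper` recursion above converts a linear thread index into a Rank-dimensional coordinate in row-major order. Below is a host-side C++ sketch of the same mapping, written iteratively for clarity; the helper function is illustrative only and is not part of CUTLASS.

```cpp
#include <array>
#include <cstdint>
#include <iostream>

// Decompose a linear index into a Rank-dimensional coordinate against `size`,
// mirroring what TensorForEachHelper does recursively on the device:
//   coord[r] = index / prod(size[r+1 .. Rank)), then recurse on the remainder.
template <int Rank>
std::array<int64_t, Rank> index_to_coord(int64_t index,
                                         std::array<int64_t, Rank> const &size) {
  std::array<int64_t, Rank> coord{};
  for (int r = 0; r < Rank; ++r) {
    int64_t product = 1;
    for (int i = r + 1; i < Rank; ++i) {
      product *= size[i];  // elements covered by one step along rank r
    }
    coord[r] = index / product;
    index %= product;      // the `remaining` value passed to the next level
  }
  return coord;
}

int main() {
  std::array<int64_t, 3> size = {2, 3, 4};
  auto c = index_to_coord<3>(17, size);  // 17 = 1*12 + 1*4 + 1
  std::cout << c[0] << "," << c[1] << "," << c[2] << std::endl;  // prints 1,1,1
  return 0;
}
```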
diff --git a/docs/device_2tensor__fill_8h.html b/docs/device_2tensor__fill_8h.html index c2af8c44d3..5a99459cfc 100644 --- a/docs/device_2tensor__fill_8h.html +++ b/docs/device_2tensor__fill_8h.html @@ -237,7 +237,7 @@   template<typename Element , typename Layout > void cutlass::reference::device::TensorFillIdentity (TensorView< Element, Layout > view) - Fills a tensor's digonal with 1 and 0 everywhere else. More...
+ Fills a tensor's diagonal with 1 and 0 everywhere else. More...
  template<typename Element , typename Layout > void cutlass::reference::device::TensorUpdateDiagonal (TensorView< Element, Layout > view, Element diag=Element(1)) diff --git a/docs/device_2tensor__fill_8h_source.html b/docs/device_2tensor__fill_8h_source.html index dd5debdaa8..908ccba847 100644 --- a/docs/device_2tensor__fill_8h_source.html +++ b/docs/device_2tensor__fill_8h_source.html @@ -125,7 +125,7 @@
Parameters structure.
Definition: device/tensor_fill.h:99
Kind kind
Active variant kind.
Definition: distribution.h:64
Params(TensorView view_=TensorView(), typename RandomFunc::Params random_=RandomFunc::Params())
Construction of Gaussian RNG functor.
Definition: device/tensor_fill.h:422
-
void TensorFillIdentity(TensorView< Element, Layout > view)
Fills a tensor&#39;s digonal with 1 and 0 everywhere else.
Definition: device/tensor_fill.h:630
+
void TensorFillIdentity(TensorView< Element, Layout > view)
Fills a tensor&#39;s diagonal with 1 and 0 everywhere else.
Definition: device/tensor_fill.h:630
CUTLASS_HOST_DEVICE TensorCoord const & extent() const
Returns the extent of the view (the size along each logical dimension).
Definition: tensor_view.h:167
Computes a random Gaussian distribution.
Definition: device/tensor_fill.h:645
int int_scale
Definition: device/tensor_fill.h:315
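The `TensorFillIdentity` change above is purely a comment fix, but for orientation, this is roughly how the device reference function is used together with `cutlass::HostTensor`. This is a sketch only; it assumes the CUTLASS `tools/util` headers are available and the translation unit is compiled with nvcc.

```cpp
#include "cutlass/layout/matrix.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/device/tensor_fill.h"

int main() {
  // An 8x8 row-major tensor allocated in both host and device memory.
  cutlass::HostTensor<float, cutlass::layout::RowMajor> tensor({8, 8});

  // Fill the device-side copy so element (i, j) is 1 when i == j and 0 otherwise.
  cutlass::reference::device::TensorFillIdentity(tensor.device_view());

  // Copy back to the host for inspection.
  tensor.sync_host();
  return 0;
}
```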
diff --git a/docs/device_2tensor__foreach_8h_source.html b/docs/device_2tensor__foreach_8h_source.html index 90c5402ef2..0380fa93f3 100644 --- a/docs/device_2tensor__foreach_8h_source.html +++ b/docs/device_2tensor__foreach_8h_source.html @@ -98,7 +98,7 @@
device/tensor_foreach.h
-Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
25 #pragma once
26 
27 #include <stdexcept>
28 #include "cutlass/cutlass.h"
30 
31 namespace cutlass {
32 namespace reference {
33 namespace device {
34 
36 
38 template <typename Func, int Rank, typename Params>
39 struct TensorForEach {
40 
42  TensorForEach(Coord<Rank> size, Params params = Params(), int grid_size = 0, int block_size = 0) {
43 
44  if (!grid_size || !block_size) {
45 
46  // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API
47  cudaError_t result = cudaOccupancyMaxPotentialBlockSize(
48  &grid_size,
49  &block_size,
50  reinterpret_cast<void const *>(kernel::TensorForEach<Func, Rank, Params>));
51 
52  if (result != cudaSuccess) {
53  throw std::runtime_error("Failed to query occupancy.");
54  }
55 
56  // Limit block size. This has the effect of increasing the number of items processed by a
57  // single thread and reduces the impact of initialization overhead.
58  block_size = (block_size < 128 ? block_size : 128);
59  }
60 
61  dim3 grid(grid_size, 1, 1);
62  dim3 block(block_size, 1, 1);
63 
64  kernel::TensorForEach<Func, Rank, Params><<< grid, block >>>(size, params);
65  }
66 };
67 
69 
71 template <typename Func, int Rank, typename Params>
73 
75  TensorDiagonalForEach(Coord<Rank> size, Params params = Params(), int start = 0, int end = -1, int block_size = 128) {
76 
77  if (end < 0) {
78  end = size.min();
79  }
80 
81  dim3 block(block_size, 1, 1);
82  dim3 grid((end - start + block_size - 1) / block_size, 1, 1);
83 
84  kernel::TensorDiagonalForEach<Func, Rank, Params><<< grid, block >>>(size, params, start, end);
85  }
86 };
87 
88 
90 
91 template <typename Element, typename Func>
92 struct BlockForEach {
93 
96  Element *ptr,
97  size_t capacity,
98  typename Func::Params params = typename Func::Params(),
99  int grid_size = 0,
100  int block_size = 0) {
101 
102  if (!grid_size || !block_size) {
103 
104  // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API
105  cudaError_t result = cudaOccupancyMaxPotentialBlockSize(
106  &grid_size,
107  &block_size,
108  reinterpret_cast<void const *>(kernel::BlockForEach<Element, Func>));
109 
110  if (result != cudaSuccess) {
111  throw std::runtime_error("Failed to query occupancy.");
112  }
113 
114  // Limit block size. This has the effect of increasing the number of items processed by a
115  // single thread and reduces the impact of initialization overhead.
116  block_size = (block_size < 128 ? block_size : 128);
117  }
118 
119  dim3 grid(grid_size, 1, 1);
120  dim3 block(block_size, 1, 1);
121 
122  kernel::BlockForEach<Element, Func><<< grid, block >>>(ptr, capacity, params);
123  }
124 };
125 
127 
128 } // namespace device
129 } // namespace reference
130 } // namesace cutlass
Definition: aligned_buffer.h:35
+Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
25 #pragma once
26 
27 #include <stdexcept>
28 #include "cutlass/cutlass.h"
30 
31 namespace cutlass {
32 namespace reference {
33 namespace device {
34 
36 
38 template <typename Func, int Rank, typename Params>
39 struct TensorForEach {
40 
42  TensorForEach(Coord<Rank> size, Params params = Params(), int grid_size = 0, int block_size = 0) {
43 
44  if (!grid_size || !block_size) {
45 
46  // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API
47  cudaError_t result = cudaOccupancyMaxPotentialBlockSize(
48  &grid_size,
49  &block_size,
50  reinterpret_cast<void const *>(kernel::TensorForEach<Func, Rank, Params>));
51 
52  if (result != cudaSuccess) {
53  throw std::runtime_error("Failed to query occupancy.");
54  }
55 
56  // Limit block size. This has the effect of increasing the number of items processed by a
57  // single thread and reduces the impact of initialization overhead.
58  block_size = (block_size < 128 ? block_size : 128);
59  }
60 
61  dim3 grid(grid_size, 1, 1);
62  dim3 block(block_size, 1, 1);
63 
64  kernel::TensorForEach<Func, Rank, Params><<< grid, block >>>(size, params);
65  }
66 };
67 
69 
71 template <typename Func, int Rank, typename Params>
73 
75  TensorDiagonalForEach(Coord<Rank> size, Params params = Params(), int start = 0, int end = -1, int block_size = 128) {
76 
77  if (end < 0) {
78  end = size.min();
79  }
80 
81  dim3 block(block_size, 1, 1);
82  dim3 grid((end - start + block_size - 1) / block_size, 1, 1);
83 
84  kernel::TensorDiagonalForEach<Func, Rank, Params><<< grid, block >>>(size, params, start, end);
85  }
86 };
87 
88 
90 
91 template <typename Element, typename Func>
92 struct BlockForEach {
93 
96  Element *ptr,
97  size_t capacity,
98  typename Func::Params params = typename Func::Params(),
99  int grid_size = 0,
100  int block_size = 0) {
101 
102  if (!grid_size || !block_size) {
103 
104  // if grid_size or block_size are zero, query occupancy using the CUDA Occupancy API
105  cudaError_t result = cudaOccupancyMaxPotentialBlockSize(
106  &grid_size,
107  &block_size,
108  reinterpret_cast<void const *>(kernel::BlockForEach<Element, Func>));
109 
110  if (result != cudaSuccess) {
111  throw std::runtime_error("Failed to query occupancy.");
112  }
113 
114  // Limit block size. This has the effect of increasing the number of items processed by a
115  // single thread and reduces the impact of initialization overhead.
116  block_size = (block_size < 128 ? block_size : 128);
117  }
118 
119  dim3 grid(grid_size, 1, 1);
120  dim3 block(block_size, 1, 1);
121 
122  kernel::BlockForEach<Element, Func><<< grid, block >>>(ptr, capacity, params);
123  }
124 };
125 
127 
128 } // namespace device
129 } // namespace reference
130 } // namespace cutlass
Definition: aligned_buffer.h:35
TensorDiagonalForEach(Coord< Rank > size, Params params=Params(), int start=0, int end=-1, int block_size=128)
Constructor performs the operation.
Definition: device/tensor_foreach.h:75
TensorForEach(Coord< Rank > size, Params params=Params(), int grid_size=0, int block_size=0)
Constructor performs the operation.
Definition: device/tensor_foreach.h:42
Launches a kernel calling a functor for each element along a tensor&#39;s diagonal.
Definition: device/tensor_foreach.h:72
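When no grid or block size is supplied, the `TensorForEach`, `TensorDiagonalForEach`, and `BlockForEach` wrappers above derive one from the CUDA occupancy API and then cap the block size at 128 threads. A self-contained CUDA sketch of that launch-configuration pattern follows; the `fill_kernel` used here is invented for illustration and is not part of CUTLASS.

```cpp
#include <cstdio>
#include <cuda_runtime.h>

// Simple grid-stride kernel standing in for kernel::TensorForEach / BlockForEach.
__global__ void fill_kernel(float *ptr, size_t capacity, float value) {
  for (size_t i = threadIdx.x + blockIdx.x * size_t(blockDim.x); i < capacity;
       i += size_t(blockDim.x) * gridDim.x) {
    ptr[i] = value;
  }
}

int main() {
  size_t capacity = size_t(1) << 20;
  float *ptr = nullptr;
  cudaMalloc(&ptr, capacity * sizeof(float));

  // Let the runtime propose a grid/block pair for maximum occupancy, exactly as
  // the wrappers above do, then cap the block size so each thread handles more
  // elements and launch overhead is better amortized.
  int grid_size = 0;
  int block_size = 0;
  cudaError_t result = cudaOccupancyMaxPotentialBlockSize(
      &grid_size, &block_size, reinterpret_cast<void const *>(fill_kernel));
  if (result != cudaSuccess) {
    std::printf("Failed to query occupancy.\n");
    return 1;
  }
  block_size = (block_size < 128 ? block_size : 128);

  fill_kernel<<<grid_size, block_size>>>(ptr, capacity, 1.0f);
  cudaDeviceSynchronize();
  cudaFree(ptr);
  return 0;
}
```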
diff --git a/docs/functions_func_s.html b/docs/functions_func_s.html index 785813104b..548b6d97d9 100644 --- a/docs/functions_func_s.html +++ b/docs/functions_func_s.html @@ -141,7 +141,7 @@

- s -

-Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include "cutlass/cutlass.h"
32 #include "cutlass/array.h"
33 #include "cutlass/aligned_buffer.h"
35 
36 #include "cutlass/numeric_types.h"
37 #include "cutlass/matrix_shape.h"
38 
39 #include "cutlass/gemm/gemm.h"
41 
43 
44 namespace cutlass {
45 namespace gemm {
46 namespace threadblock {
47 
49 
51 template <
53  typename Shape_,
55  // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
56  typename IteratorA_,
59  typename SmemIteratorA_,
61  // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
62  typename IteratorB_,
65  typename SmemIteratorB_,
67  typename ElementC_,
69  typename LayoutC_,
71  typename Policy_,
73  typename TransformA_ = NumericArrayConverter<
74  typename SmemIteratorA_::Element,
75  typename IteratorA_::Element,
76  IteratorA_::Fragment::kElements>,
79  typename TransformB_ = NumericArrayConverter<
80  typename SmemIteratorB_::Element,
81  typename IteratorB_::Element,
82  IteratorB_::Fragment::kElements>,
84  typename Enable = bool
85 >
86 class MmaPipelined : public MmaBase<Shape_, Policy_, 2> {
87 public:
88 
91 
92  using Shape = Shape_;
93  using IteratorA = IteratorA_;
94  using IteratorB = IteratorB_;
95  using ElementC = ElementC_;
96  using LayoutC = LayoutC_;
97  using Policy = Policy_;
98 
99  using SmemIteratorA = SmemIteratorA_;
100  using SmemIteratorB = SmemIteratorB_;
101 
102  using TransformA = TransformA_;
103  using TransformB = TransformB_;
104 
105  //
106  // Dependent types
107  //
108 
110  using FragmentA = typename IteratorA::Fragment;
111 
113  using FragmentB = typename IteratorB::Fragment;
114 
116  using FragmentC = typename Policy::Operator::FragmentC;
117 
119  using Operator = typename Policy::Operator;
120 
121  // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline)
122  static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");
123 
124 private:
125 
126  using WarpFragmentA = typename Operator::FragmentA;
127  using WarpFragmentB = typename Operator::FragmentB;
128 
129 protected:
130 
133 
136 
137 public:
138 
140  CUTLASS_DEVICE
142  typename Base::SharedStorage &shared_storage,
143  int thread_idx,
144  int warp_idx,
145  int lane_idx
146  ):
147  Base(shared_storage, thread_idx, warp_idx, lane_idx),
148  smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
149  smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {
150 
151  // Compute warp location within threadblock tile by mapping the warp_id to
152  // three coordinates:
153  // _m: the warp's position within the threadblock along the M dimension
154  // _n: the warp's position within the threadblock along the N dimension
155  // _k: the warp's position within the threadblock along the K dimension
156 
157  int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
158  int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
159 
160  int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
161  int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
162 
163  // Add per-warp offsets in units of warp-level tiles
164  this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
165  this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
166  }
167 
169  CUTLASS_DEVICE
171  int gemm_k_iterations,
172  FragmentC &accum,
173  IteratorA iterator_A,
174  IteratorB iterator_B,
175  FragmentC const &src_accum,
176  TransformA transform_A = TransformA(),
177  TransformB transform_B = TransformB()) {
178 
179  //
180  // Prologue
181  //
182 
183  // Perform accumulation in the 'd' output operand
184  accum = src_accum;
185 
186  FragmentA tb_frag_A;
187  FragmentB tb_frag_B;
188 
189  tb_frag_A.clear();
190  tb_frag_B.clear();
191 
192  // The last kblock is loaded in the prolog
193  iterator_A.load(tb_frag_A);
194  iterator_B.load(tb_frag_B);
195 
196  ++iterator_A;
197  ++iterator_B;
198 
199  this->smem_iterator_A_.store(transform_A(tb_frag_A));
200  this->smem_iterator_B_.store(transform_B(tb_frag_B));
201 
202  ++this->smem_iterator_A_;
203  ++this->smem_iterator_B_;
204 
205  __syncthreads();
206 
207  // Pair of fragments used to overlap shared memory loads and math instructions
208  WarpFragmentA warp_frag_A[2];
209  WarpFragmentB warp_frag_B[2];
210 
211  this->warp_tile_iterator_A_.set_kgroup_index(0);
212  this->warp_tile_iterator_B_.set_kgroup_index(0);
213 
214  this->warp_tile_iterator_A_.load(warp_frag_A[0]);
215  this->warp_tile_iterator_B_.load(warp_frag_B[0]);
216 
217  ++this->warp_tile_iterator_A_;
218  ++this->warp_tile_iterator_B_;
219 
220  Operator warp_mma;
221 
222  int smem_write_stage_idx = 1;
223 
224  // Avoid reading out of bounds
225  if (gemm_k_iterations <= 1) {
226  iterator_A.clear_mask();
227  iterator_B.clear_mask();
228  }
229 
230  // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
231  // shared memory loads (which have the tighest latency requirement).
232 
233  //
234  // Mainloop
235  //
236 
237  // Note: The main loop does not support Base::kWarpGemmIterations == 2.
239  for (; gemm_k_iterations > 0; --gemm_k_iterations) {
240  //
241  // Loop over GEMM K dimension
242  //
243 
245  for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
246 
247  // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
248  // as the case may be.
249 
250  if (warp_mma_k == Base::kWarpGemmIterations - 1) {
251 
252  // Write fragments to shared memory
253  this->smem_iterator_A_.store(transform_A(tb_frag_A));
254 
255  this->smem_iterator_B_.store(transform_B(tb_frag_B));
256 
257  __syncthreads();
258 
259  ++this->smem_iterator_B_;
260  ++this->smem_iterator_A_;
261 
262  // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
263  if (smem_write_stage_idx == 1) {
264  this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
265  this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
266  }
267  else {
268  this->warp_tile_iterator_A_.add_tile_offset(
269  {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
270  this->warp_tile_iterator_B_.add_tile_offset(
271  {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations,
272  0});
273  }
274 
275  smem_write_stage_idx ^= 1;
276  }
277 
278  this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
279  this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
280 
281  this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
282  this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]);
283 
284  ++this->warp_tile_iterator_A_;
285  ++this->warp_tile_iterator_B_;
286 
287  if (warp_mma_k == 0) {
288 
289  iterator_A.load(tb_frag_A);
290  iterator_B.load(tb_frag_B);
291 
292  ++iterator_A;
293  ++iterator_B;
294 
295  // Avoid reading out of bounds if this was the last loop iteration
296  if (gemm_k_iterations <= 2) {
297  iterator_A.clear_mask();
298  iterator_B.clear_mask();
299  }
300  }
301 
302  warp_mma(accum, warp_frag_A[warp_mma_k % 2], warp_frag_B[warp_mma_k % 2], accum);
303  }
304  }
305 
306  }
307 };
308 
310 
311 } // namespace threadblock
312 } // namespace gemm
313 } // namespace cutlass
static int const kM
Definition: include/cutlass/gemm/gemm.h:58
+Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TOR (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include "cutlass/cutlass.h"
32 #include "cutlass/array.h"
33 #include "cutlass/aligned_buffer.h"
35 
36 #include "cutlass/numeric_types.h"
37 #include "cutlass/matrix_shape.h"
38 
39 #include "cutlass/gemm/gemm.h"
41 
43 
44 namespace cutlass {
45 namespace gemm {
46 namespace threadblock {
47 
49 
51 template <
53  typename Shape_,
55  // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
56  typename IteratorA_,
59  typename SmemIteratorA_,
61  // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator)
62  typename IteratorB_,
65  typename SmemIteratorB_,
67  typename ElementC_,
69  typename LayoutC_,
71  typename Policy_,
73  typename TransformA_ = NumericArrayConverter<
74  typename SmemIteratorA_::Element,
75  typename IteratorA_::Element,
76  IteratorA_::Fragment::kElements>,
79  typename TransformB_ = NumericArrayConverter<
80  typename SmemIteratorB_::Element,
81  typename IteratorB_::Element,
82  IteratorB_::Fragment::kElements>,
84  typename Enable = bool
85 >
86 class MmaPipelined : public MmaBase<Shape_, Policy_, 2> {
87 public:
88 
91 
92  using Shape = Shape_;
93  using IteratorA = IteratorA_;
94  using IteratorB = IteratorB_;
95  using ElementC = ElementC_;
96  using LayoutC = LayoutC_;
97  using Policy = Policy_;
98 
99  using SmemIteratorA = SmemIteratorA_;
100  using SmemIteratorB = SmemIteratorB_;
101 
102  using TransformA = TransformA_;
103  using TransformB = TransformB_;
104 
105  //
106  // Dependent types
107  //
108 
110  using FragmentA = typename IteratorA::Fragment;
111 
113  using FragmentB = typename IteratorB::Fragment;
114 
116  using FragmentC = typename Policy::Operator::FragmentC;
117 
119  using Operator = typename Policy::Operator;
120 
121  // statically assert kStages for MmaPipelined is two (Double-buffered pipeline)
122  static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2");
123 
124 private:
125 
126  using WarpFragmentA = typename Operator::FragmentA;
127  using WarpFragmentB = typename Operator::FragmentB;
128 
129 protected:
130 
133 
136 
137 public:
138 
140  CUTLASS_DEVICE
142  typename Base::SharedStorage &shared_storage,
143  int thread_idx,
144  int warp_idx,
145  int lane_idx
146  ):
147  Base(shared_storage, thread_idx, warp_idx, lane_idx),
148  smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx),
149  smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) {
150 
151  // Compute warp location within threadblock tile by mapping the warp_id to
152  // three coordinates:
153  // _m: the warp's position within the threadblock along the M dimension
154  // _n: the warp's position within the threadblock along the N dimension
155  // _k: the warp's position within the threadblock along the K dimension
156 
157  int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN);
158  int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN);
159 
160  int warp_idx_m = warp_idx_mn % Base::WarpCount::kM;
161  int warp_idx_n = warp_idx_mn / Base::WarpCount::kM;
162 
163  // Add per-warp offsets in units of warp-level tiles
164  this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k});
165  this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n});
166  }
167 
169  CUTLASS_DEVICE
171  int gemm_k_iterations,
172  FragmentC &accum,
173  IteratorA iterator_A,
174  IteratorB iterator_B,
175  FragmentC const &src_accum,
176  TransformA transform_A = TransformA(),
177  TransformB transform_B = TransformB()) {
178 
179  //
180  // Prologue
181  //
182 
183  // Perform accumulation in the 'd' output operand
184  accum = src_accum;
185 
186  FragmentA tb_frag_A;
187  FragmentB tb_frag_B;
188 
189  tb_frag_A.clear();
190  tb_frag_B.clear();
191 
192  // The last kblock is loaded in the prolog
193  iterator_A.load(tb_frag_A);
194  iterator_B.load(tb_frag_B);
195 
196  ++iterator_A;
197  ++iterator_B;
198 
199  this->smem_iterator_A_.store(transform_A(tb_frag_A));
200  this->smem_iterator_B_.store(transform_B(tb_frag_B));
201 
202  ++this->smem_iterator_A_;
203  ++this->smem_iterator_B_;
204 
205  __syncthreads();
206 
207  // Pair of fragments used to overlap shared memory loads and math instructions
208  WarpFragmentA warp_frag_A[2];
209  WarpFragmentB warp_frag_B[2];
210 
211  this->warp_tile_iterator_A_.set_kgroup_index(0);
212  this->warp_tile_iterator_B_.set_kgroup_index(0);
213 
214  this->warp_tile_iterator_A_.load(warp_frag_A[0]);
215  this->warp_tile_iterator_B_.load(warp_frag_B[0]);
216 
217  ++this->warp_tile_iterator_A_;
218  ++this->warp_tile_iterator_B_;
219 
220  Operator warp_mma;
221 
222  int smem_write_stage_idx = 1;
223 
224  // Avoid reading out of bounds
225  if (gemm_k_iterations <= 1) {
226  iterator_A.clear_mask();
227  iterator_B.clear_mask();
228  }
229 
230  // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing
231  // shared memory loads (which have the tightest latency requirement).
232 
233  //
234  // Mainloop
235  //
236 
237  // Note: The main loop does not support Base::kWarpGemmIterations == 2.
239  for (; gemm_k_iterations > 0; --gemm_k_iterations) {
240  //
241  // Loop over GEMM K dimension
242  //
243 
245  for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) {
246 
247  // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group
248  // as the case may be.
249 
250  if (warp_mma_k == Base::kWarpGemmIterations - 1) {
251 
252  // Write fragments to shared memory
253  this->smem_iterator_A_.store(transform_A(tb_frag_A));
254 
255  this->smem_iterator_B_.store(transform_B(tb_frag_B));
256 
257  __syncthreads();
258 
259  ++this->smem_iterator_B_;
260  ++this->smem_iterator_A_;
261 
262  // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory
263  if (smem_write_stage_idx == 1) {
264  this->smem_iterator_A_.add_tile_offset({0, -Base::kStages});
265  this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0});
266  }
267  else {
268  this->warp_tile_iterator_A_.add_tile_offset(
269  {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations});
270  this->warp_tile_iterator_B_.add_tile_offset(
271  {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations,
272  0});
273  }
274 
275  smem_write_stage_idx ^= 1;
276  }
277 
278  this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
279  this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations);
280 
281  this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]);
282  this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]);
283 
284  ++this->warp_tile_iterator_A_;
285  ++this->warp_tile_iterator_B_;
286 
287  if (warp_mma_k == 0) {
288 
289  iterator_A.load(tb_frag_A);
290  iterator_B.load(tb_frag_B);
291 
292  ++iterator_A;
293  ++iterator_B;
294 
295  // Avoid reading out of bounds if this was the last loop iteration
296  if (gemm_k_iterations <= 2) {
297  iterator_A.clear_mask();
298  iterator_B.clear_mask();
299  }
300  }
301 
302  warp_mma(accum, warp_frag_A[warp_mma_k % 2], warp_frag_B[warp_mma_k % 2], accum);
303  }
304  }
305 
306  }
307 };
308 
310 
311 } // namespace threadblock
312 } // namespace gemm
313 } // namespace cutlass
static int const kM
Definition: include/cutlass/gemm/gemm.h:58
LayoutC_ LayoutC
Layout of accumulator matrix.
Definition: mma_pipelined.h:96
TransformB_ TransformB
Definition: mma_pipelined.h:103
Definition: aligned_buffer.h:35
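One detail worth spelling out from the `MmaPipelined` constructor shown (unchanged) in both halves of this hunk: the mapping of `warp_idx` to `(m, n, k)` coordinates is a plain mixed-radix decomposition over the threadblock's warp counts. The host-side sketch below uses arbitrary example warp counts rather than values from any particular CUTLASS policy.

```cpp
#include <cstdio>

int main() {
  // Example warp counts along M, N, and K (illustrative values only).
  const int kWarpCountM = 2;
  const int kWarpCountN = 2;
  const int kWarpCountK = 2;

  for (int warp_idx = 0; warp_idx < kWarpCountM * kWarpCountN * kWarpCountK; ++warp_idx) {
    // Same arithmetic as MmaPipelined's constructor.
    int warp_idx_mn = warp_idx % (kWarpCountM * kWarpCountN);
    int warp_idx_k  = warp_idx / (kWarpCountM * kWarpCountN);
    int warp_idx_m  = warp_idx_mn % kWarpCountM;
    int warp_idx_n  = warp_idx_mn / kWarpCountM;
    std::printf("warp %d -> (m=%d, n=%d, k=%d)\n",
                warp_idx, warp_idx_m, warp_idx_n, warp_idx_k);
  }
  return 0;
}
```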
diff --git a/docs/namespacecutlass_1_1gemm_1_1device.html b/docs/namespacecutlass_1_1gemm_1_1device.html index 7023f4f88a..c0b27cbe97 100644 --- a/docs/namespacecutlass_1_1gemm_1_1device.html +++ b/docs/namespacecutlass_1_1gemm_1_1device.html @@ -134,17 +134,17 @@ class  Gemm   class  Gemm< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, SplitKSerial, Operator_, IsBetaZero > - Parital specialization for column-major output exchanges problem size and operand. More...
+ Partial specialization for column-major output exchanges problem size and operand. More...
  class  GemmBatched   class  GemmBatched< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, AlignmentA, AlignmentB, Operator_ > - Parital specialization for column-major output exchanges problem size and operand. More...
+ Partial specialization for column-major output exchanges problem size and operand. More...
  class  GemmComplex   class  GemmComplex< ElementA_, LayoutA_, ElementB_, LayoutB_, ElementC_, layout::ColumnMajor, ElementAccumulator_, OperatorClass_, ArchTag_, ThreadblockShape_, WarpShape_, InstructionShape_, EpilogueOutputOp_, ThreadblockSwizzle_, Stages, TransformA, TransformB, SplitKSerial > - Parital specialization for column-major output exchanges problem size and operand. More...
+ Partial specialization for column-major output exchanges problem size and operand. More...
  class  GemmSplitKParallel   diff --git a/docs/namespacecutlass_1_1reference_1_1device.html b/docs/namespacecutlass_1_1reference_1_1device.html index 86f21a00bb..54f5009f41 100644 --- a/docs/namespacecutlass_1_1reference_1_1device.html +++ b/docs/namespacecutlass_1_1reference_1_1device.html @@ -125,7 +125,7 @@  Partial specialization for multiply-add-saturate. More...
  struct  Gemm< ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType, AccumulatorType, arch::OpXorPopc > - Parital specialization for XOR-popc. More...
+ Partial specialization for XOR-popc. More...
  struct  TensorDiagonalForEach  Launches a kernel calling a functor for each element along a tensor's diagonal. More...
@@ -183,7 +183,7 @@   template<typename Element , typename Layout > void TensorFillIdentity (TensorView< Element, Layout > view) - Fills a tensor's digonal with 1 and 0 everywhere else. More...
+ Fills a tensor's diagonal with 1 and 0 everywhere else. More...
  template<typename Element , typename Layout > void TensorUpdateDiagonal (TensorView< Element, Layout > view, Element diag=Element(1)) diff --git a/docs/namespacecutlass_1_1reference_1_1host.html b/docs/namespacecutlass_1_1reference_1_1host.html index 90f9a01442..194bdfe88d 100644 --- a/docs/namespacecutlass_1_1reference_1_1host.html +++ b/docs/namespacecutlass_1_1reference_1_1host.html @@ -122,7 +122,7 @@  Partial specialization for multiply-add-saturate. More...
  struct  Gemm< ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType, ComputeType, arch::OpXorPopc > - Parital specialization for XOR-popc. More...
+ Partial specialization for XOR-popc. More...
  - + @@ -1677,7 +1677,7 @@

Function Documentation

@@ -247,7 +247,7 @@

 
template<typename Element , typename Layout >
void TensorFillIdentity (TensorView< Element, Layout > dst)
 Helper to fill a tensor's digonal with 1 and 0 everywhere else. More...
 Helper to fill a tensor's diagonal with 1 and 0 everywhere else. More...
 
template<typename Element , typename Layout >
void TensorUpdateDiagonal (TensorView< Element, Layout > dst, Element val=Element(1))
-

Returns a pair containing a boolean of whether a value exists in a tensor and the location of of the first occurrence. If the value is not contained in the tensor, the second element of the pair is undefined.

+

Returns a pair containing a boolean of whether a value exists in a tensor and the location of the first occurrence. If the value is not contained in the tensor, the second element of the pair is undefined.

diff --git a/docs/namespacecutlass_1_1transform.html b/docs/namespacecutlass_1_1transform.html index 974fe32468..c8eb5ba164 100644 --- a/docs/namespacecutlass_1_1transform.html +++ b/docs/namespacecutlass_1_1transform.html @@ -128,7 +128,7 @@ struct  TransposePitchLinearThreadMap   struct  TransposePitchLinearThreadMap2DThreadTile - Thread Mapping a 2D threadtiled mapping as a tranposed Pitchlinear2DThreadTile mapping. More...
+ Thread Mapping a 2D threadtiled mapping as a transposed Pitchlinear2DThreadTile mapping. More...
  struct  TransposePitchLinearThreadMapSimt   diff --git a/docs/pitch__linear__thread__map_8h.html b/docs/pitch__linear__thread__map_8h.html index 13fef0cc77..0266464109 100644 --- a/docs/pitch__linear__thread__map_8h.html +++ b/docs/pitch__linear__thread__map_8h.html @@ -164,7 +164,7 @@  Internal implementation details. More...
  struct  cutlass::transform::TransposePitchLinearThreadMap2DThreadTile< ThreadMap_ > - Thread Mapping a 2D threadtiled mapping as a tranposed Pitchlinear2DThreadTile mapping. More...
+ Thread Mapping a 2D threadtiled mapping as a transposed Pitchlinear2DThreadTile mapping. More...
  - + diff --git a/docs/structcutlass_1_1CommandLine.html b/docs/structcutlass_1_1CommandLine.html index 0bde0ec4ad..01cd35a270 100644 --- a/docs/structcutlass_1_1CommandLine.html +++ b/docs/structcutlass_1_1CommandLine.html @@ -151,7 +151,7 @@ - +

diff --git a/docs/pitch__linear__thread__map_8h_source.html b/docs/pitch__linear__thread__map_8h_source.html index fcdf70ae10..9f18269e15 100644 --- a/docs/pitch__linear__thread__map_8h_source.html +++ b/docs/pitch__linear__thread__map_8h_source.html @@ -129,7 +129,7 @@
Definition: pitch_linear_thread_map.h:491
static CUTLASS_HOST_DEVICE TensorCoord initial_offset(int thread_id)
Definition: pitch_linear_thread_map.h:187
WarpThreadArrangement_ WarpThreadArrangement
Fixed arrangement of threads within a warp (units of threads).
Definition: pitch_linear_thread_map.h:226
-
Thread Mapping a 2D threadtiled mapping as a tranposed Pitchlinear2DThreadTile mapping.
Definition: pitch_linear_thread_map.h:713
+
Thread Mapping a 2D threadtiled mapping as a transposed Pitchlinear2DThreadTile mapping.
Definition: pitch_linear_thread_map.h:713
WarpThreadArrangement_ WarpThreadArrangement
Fixed arrangement of threads within a warp (units of threads).
Definition: pitch_linear_thread_map.h:355
Internal details made public to facilitate introspection Iterations along each dimension (concept: Pi...
Definition: pitch_linear_thread_map.h:353
Definition: pitch_linear_thread_map.h:205
diff --git a/docs/search/all_12.js b/docs/search/all_12.js index 0e091040bc..c9f8a45c70 100644 --- a/docs/search/all_12.js +++ b/docs/search/all_12.js @@ -14,7 +14,7 @@ var searchData= ['semaphore',['Semaphore',['../classcutlass_1_1Semaphore.html',1,'cutlass']]], ['semaphore',['Semaphore',['../classcutlass_1_1Semaphore.html#a2ce4cd07fe773efa429f726cfbd98070',1,'cutlass::Semaphore::Semaphore()'],['../structcutlass_1_1gemm_1_1kernel_1_1Gemm_1_1Params.html#adec6d0c6d74e7f456196f453e302fbbb',1,'cutlass::gemm::kernel::Gemm::Params::semaphore()']]], ['semaphore_2eh',['semaphore.h',['../semaphore_8h.html',1,'']]], - ['seperate_5fstring',['seperate_string',['../structcutlass_1_1CommandLine.html#a5f86e4b2bd8c44b739c83530d77c5590',1,'cutlass::CommandLine']]], + ['separate_5fstring',['separate_string',['../structcutlass_1_1CommandLine.html#a5f86e4b2bd8c44b739c83530d77c5590',1,'cutlass::CommandLine']]], ['sequential',['sequential',['../structcutlass_1_1Distribution.html#ab86d975567ef141ff82067b1f41cd3ee',1,'cutlass::Distribution::sequential()'],['../structcutlass_1_1Distribution.html#a499f4023e0d42356ce71d38cc32bf92aa39d3cf55e90573c8d1dfb483cfb410dc',1,'cutlass::Distribution::Sequential()']]], ['set',['set',['../classcutlass_1_1PredicateVector_1_1Iterator.html#aadfd039b5622098c9e46706a27122575',1,'cutlass::PredicateVector::Iterator::set()'],['../structcutlass_1_1PredicateVector.html#a062fa8a8df725ef08ced2ffcca8336af',1,'cutlass::PredicateVector::set()'],['../classcutlass_1_1SubbyteReference.html#a6473e57520d8ee7afbd95c1e1641e05a',1,'cutlass::SubbyteReference::set()']]], ['set_5fgaussian',['set_gaussian',['../structcutlass_1_1Distribution.html#ad594b5ec1d577e8ef03d4d808a8220b1',1,'cutlass::Distribution']]], diff --git a/docs/search/functions_12.js b/docs/search/functions_12.js index f2b3bff9cf..6648b431e4 100644 --- a/docs/search/functions_12.js +++ b/docs/search/functions_12.js @@ -3,7 +3,7 @@ var searchData= ['scalar_5fop',['scalar_op',['../structcutlass_1_1minimum_3_01Array_3_01T_00_01N_01_4_01_4.html#a4b42227184cb7c796460062c46a84b57',1,'cutlass::minimum< Array< T, N > >']]], ['scalario',['ScalarIO',['../structcutlass_1_1ScalarIO.html#ad4166575521254088bf6c6300c351714',1,'cutlass::ScalarIO::ScalarIO()'],['../structcutlass_1_1ScalarIO.html#a5227e1e9ed24326ad4f8dc94d186186f',1,'cutlass::ScalarIO::ScalarIO(T value)']]], ['semaphore',['Semaphore',['../classcutlass_1_1Semaphore.html#a2ce4cd07fe773efa429f726cfbd98070',1,'cutlass::Semaphore']]], - ['seperate_5fstring',['seperate_string',['../structcutlass_1_1CommandLine.html#a5f86e4b2bd8c44b739c83530d77c5590',1,'cutlass::CommandLine']]], + ['separate_5fstring',['separate_string',['../structcutlass_1_1CommandLine.html#a5f86e4b2bd8c44b739c83530d77c5590',1,'cutlass::CommandLine']]], ['set',['set',['../classcutlass_1_1PredicateVector_1_1Iterator.html#aadfd039b5622098c9e46706a27122575',1,'cutlass::PredicateVector::Iterator::set()'],['../structcutlass_1_1PredicateVector.html#a062fa8a8df725ef08ced2ffcca8336af',1,'cutlass::PredicateVector::set()'],['../classcutlass_1_1SubbyteReference.html#a6473e57520d8ee7afbd95c1e1641e05a',1,'cutlass::SubbyteReference::set()']]], ['set_5fgaussian',['set_gaussian',['../structcutlass_1_1Distribution.html#ad594b5ec1d577e8ef03d4d808a8220b1',1,'cutlass::Distribution']]], ['set_5fidentity',['set_identity',['../structcutlass_1_1Distribution.html#aad2cf02af3d520544d89843cc4295858',1,'cutlass::Distribution']]], diff --git a/docs/structcutlass_1_1CommandLine-members.html b/docs/structcutlass_1_1CommandLine-members.html index 
77668951c3..6a17b2f0f5 100644 --- a/docs/structcutlass_1_1CommandLine-members.html +++ b/docs/structcutlass_1_1CommandLine-members.html @@ -115,7 +115,7 @@

keyscutlass::CommandLine
num_naked_args() const cutlass::CommandLineinline
parsed_argc() const cutlass::CommandLineinline
seperate_string(std::string const &str, std::vector< value_t > &vals, char sep= ',')cutlass::CommandLineinlinestatic
separate_string(std::string const &str, std::vector< value_t > &vals, char sep= ',')cutlass::CommandLineinlinestatic
tokenize(std::vector< std::pair< std::string, std::string > > &tokens, std::string const &str, char delim= ',', char sep= ':')cutlass::CommandLineinlinestatic
tokenize(std::vector< std::string > &tokens, std::string const &str, char delim= ',', char sep= ':')cutlass::CommandLineinlinestatic
valuescutlass::CommandLine
 Tokenizes a comma-delimited list of string pairs delimited by ':'. More...
 
template<typename value_t >
static void seperate_string (std::string const &str, std::vector< value_t > &vals, char sep= ',')
static void separate_string (std::string const &str, std::vector< value_t > &vals, char sep= ',')
 
- +

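These CommandLine helpers, with the spelling corrected to separate_string, are useful outside the profiler as well. Below is a minimal sketch, assuming cutlass::CommandLine lives in the tools/util header cutlass/util/command_line.h and behaves as the signatures on this page describe: separate_string splits a delimited list into typed values, tokenize splits a comma-delimited list of key:value pairs.

```cpp
#include <iostream>
#include <string>
#include <utility>
#include <vector>

#include "cutlass/util/command_line.h"

int main() {
  // Split a comma-delimited list into typed values.
  std::vector<int> extent;
  cutlass::CommandLine::separate_string(std::string("16,16,8"), extent, ',');

  // Tokenize a comma-delimited list of key:value pairs.
  std::vector<std::pair<std::string, std::string>> tokens;
  cutlass::CommandLine::tokenize(tokens, std::string("m:128,n:128,k:32"), ',', ':');

  for (int e : extent) {
    std::cout << e << " ";
  }
  std::cout << "\n";
  for (auto const &kv : tokens) {
    std::cout << kv.first << " = " << kv.second << "\n";
  }
  return 0;
}
```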
@@ -548,7 +548,7 @@

Member Function Documentation

- + diff --git a/docs/structcutlass_1_1reference_1_1device_1_1Gemm_3_01ElementA_00_01LayoutA_00_01ElementB_00_01Layout660562b232f408218828ca5915b7e73a.html b/docs/structcutlass_1_1reference_1_1device_1_1Gemm_3_01ElementA_00_01LayoutA_00_01ElementB_00_01Layout660562b232f408218828ca5915b7e73a.html index 37cb3e5ddc..2f4bf08eac 100644 --- a/docs/structcutlass_1_1reference_1_1device_1_1Gemm_3_01ElementA_00_01LayoutA_00_01ElementB_00_01Layout660562b232f408218828ca5915b7e73a.html +++ b/docs/structcutlass_1_1reference_1_1device_1_1Gemm_3_01ElementA_00_01LayoutA_00_01ElementB_00_01Layout660562b232f408218828ca5915b7e73a.html @@ -104,7 +104,7 @@
-

Parital specialization for XOR-popc. +

Partial specialization for XOR-popc.

#include <gemm.h>

diff --git a/docs/structcutlass_1_1reference_1_1device_1_1kernel_1_1detail_1_1TensorForEachHelper_3_01Func_00_01Rank_00_010_01_4.html b/docs/structcutlass_1_1reference_1_1device_1_1kernel_1_1detail_1_1TensorForEachHelper_3_01Func_00_01Rank_00_010_01_4.html index 2c89af687f..2daeadcc77 100644 --- a/docs/structcutlass_1_1reference_1_1device_1_1kernel_1_1detail_1_1TensorForEachHelper_3_01Func_00_01Rank_00_010_01_4.html +++ b/docs/structcutlass_1_1reference_1_1device_1_1kernel_1_1detail_1_1TensorForEachHelper_3_01Func_00_01Rank_00_010_01_4.html @@ -112,7 +112,7 @@
- +
static void cutlass::CommandLine::seperate_string static void cutlass::CommandLine::separate_string ( std::string const &  str,

Public Member Functions

__inline__ __device__ TensorForEachHelper (Func &func, Coord< Rank > const &size, Coord< Rank > &coord, int64_t index)
 Constructor for fastest chaning rank. More...
 Constructor for fastest changing rank. More...
 

Constructor & Destructor Documentation

diff --git a/docs/structcutlass_1_1reference_1_1host_1_1Gemm_3_01ElementA_00_01LayoutA_00_01ElementB_00_01LayoutB_4f3f32c4b336238abfd741e87bfced46.html b/docs/structcutlass_1_1reference_1_1host_1_1Gemm_3_01ElementA_00_01LayoutA_00_01ElementB_00_01LayoutB_4f3f32c4b336238abfd741e87bfced46.html index 0840df5980..7a83a97d8c 100644 --- a/docs/structcutlass_1_1reference_1_1host_1_1Gemm_3_01ElementA_00_01LayoutA_00_01ElementB_00_01LayoutB_4f3f32c4b336238abfd741e87bfced46.html +++ b/docs/structcutlass_1_1reference_1_1host_1_1Gemm_3_01ElementA_00_01LayoutA_00_01ElementB_00_01LayoutB_4f3f32c4b336238abfd741e87bfced46.html @@ -104,7 +104,7 @@
-

Parital specialization for XOR-popc. +

Partial specialization for XOR-popc.

#include <gemm.h>

diff --git a/docs/structcutlass_1_1reference_1_1host_1_1detail_1_1TensorForEachHelper_3_01Func_00_01Rank_00_010_01_4.html b/docs/structcutlass_1_1reference_1_1host_1_1detail_1_1TensorForEachHelper_3_01Func_00_01Rank_00_010_01_4.html index 2e440e45ad..6c63b40e38 100644 --- a/docs/structcutlass_1_1reference_1_1host_1_1detail_1_1TensorForEachHelper_3_01Func_00_01Rank_00_010_01_4.html +++ b/docs/structcutlass_1_1reference_1_1host_1_1detail_1_1TensorForEachHelper_3_01Func_00_01Rank_00_010_01_4.html @@ -113,7 +113,7 @@

Public Member Functions

 TensorForEachHelper (Func &func, Coord< Rank > const &extent, Coord< Rank > &coord)
 Constructor for fastest chaning rank. More...
 Constructor for fastest changing rank. More...
 
- +

diff --git a/docs/structcutlass_1_1transform_1_1TransposePitchLinearThreadMap2DThreadTile.html b/docs/structcutlass_1_1transform_1_1TransposePitchLinearThreadMap2DThreadTile.html index 41e0af2623..bc5294f7c1 100644 --- a/docs/structcutlass_1_1transform_1_1TransposePitchLinearThreadMap2DThreadTile.html +++ b/docs/structcutlass_1_1transform_1_1TransposePitchLinearThreadMap2DThreadTile.html @@ -106,7 +106,7 @@
-

Thread Mapping a 2D threadtiled mapping as a tranposed Pitchlinear2DThreadTile mapping. +

Thread Mapping a 2D threadtiled mapping as a transposed Pitchlinear2DThreadTile mapping.

#include <pitch_linear_thread_map.h>

diff --git a/docs/tools_2util_2include_2cutlass_2util_2reference_2device_2gemm_8h.html b/docs/tools_2util_2include_2cutlass_2util_2reference_2device_2gemm_8h.html index 2a0a978e79..cc75285506 100644 --- a/docs/tools_2util_2include_2cutlass_2util_2reference_2device_2gemm_8h.html +++ b/docs/tools_2util_2include_2cutlass_2util_2reference_2device_2gemm_8h.html @@ -134,7 +134,7 @@

 Partial specialization for multiply-add-saturate. More...
 
struct  cutlass::reference::device::Gemm< ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType, AccumulatorType, arch::OpXorPopc >
 Parital specialization for XOR-popc. More...
 Partial specialization for XOR-popc. More...
 
- +

diff --git a/docs/tools_2util_2include_2cutlass_2util_2reference_2host_2gemm_8h.html b/docs/tools_2util_2include_2cutlass_2util_2reference_2host_2gemm_8h.html index d20a078440..b0bfdbc283 100644 --- a/docs/tools_2util_2include_2cutlass_2util_2reference_2host_2gemm_8h.html +++ b/docs/tools_2util_2include_2cutlass_2util_2reference_2host_2gemm_8h.html @@ -141,7 +141,7 @@

 Partial specialization for multiply-add-saturate. More...
 
struct  cutlass::reference::host::Gemm< ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, ScalarType, ComputeType, arch::OpXorPopc >
 Parital specialization for XOR-popc. More...
 Partial specialization for XOR-popc. More...
 

diff --git a/docs/wmma__sm75_8h_source.html b/docs/wmma__sm75_8h_source.html index 72ad72f91f..6ff6405d34 100644 --- a/docs/wmma__sm75_8h_source.html +++ b/docs/wmma__sm75_8h_source.html @@ -98,7 +98,7 @@
wmma_sm75.h
-Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include <assert.h>
32 #include "cutlass/layout/matrix.h"
33 
35 namespace cutlass {
36 namespace arch {
37 
39 //
40 // WMMA template structure defines nvcuda::wmma::fragments and static assert for
41 // wmma native instruction sizes supported for cutlass::int4b_t (experimental::s4).
42 //
44 template <
45 typename Shape_,
46 typename LayoutA_,
47 typename LayoutB_,
48 typename LayoutC_>
49 struct Wmma<
50  Shape_,
51  cutlass::int4b_t,
52  LayoutA_,
54  LayoutB_,
55  int32_t,
56  LayoutC_,
57  cutlass::arch::OpMultiplyAdd
58 > {
59 #if defined(CUTLASS_ARCH_WMMA_SM75_ENABLED)
60  using Shape = Shape_;
61  using ElementA = cutlass::int4b_t;
62  using LayoutA = LayoutA_;
63  using ElementB = cutlass::int4b_t;
64  using LayoutB = LayoutB_;
65  using ElementC = int32_t;
66  using LayoutC = LayoutC_;
67  using Operator = cutlass::arch::OpMultiplyAdd;
68 
69  // check supported wmma shape for the given multiplicand data types
72  "Supported list of wmma operator shape for s8 multiplicands is: 8x8x32");
73 
74 
75  // Wmma Fragment
76  using FragmentA = nvcuda::wmma::fragment<
77  nvcuda::wmma::matrix_a,
78  Shape::kM,
79  Shape::kN,
80  Shape::kK,
81  typename CutlassToWmmaDataType<ElementA>::Type,
82  typename CutlassToWmmaLayout<LayoutA>::Layout>;
83 
84  using FragmentB = nvcuda::wmma::fragment<
85  nvcuda::wmma::matrix_b,
86  Shape::kM,
87  Shape::kN,
88  Shape::kK,
89  typename CutlassToWmmaDataType<ElementB>::Type,
90  typename CutlassToWmmaLayout<LayoutB>::Layout>;
91 
92  using FragmentC = nvcuda::wmma::fragment<
93  nvcuda::wmma::accumulator,
94  Shape::kM,
95  Shape::kN,
96  Shape::kK,
97  typename CutlassToWmmaDataType<ElementC>::Type>;
98 
100  CUTLASS_DEVICE
101  void operator()(
102  FragmentC &D,
103  FragmentA const &A,
104  FragmentB const &B,
105  FragmentC const &C) const {
106  nvcuda::wmma::mma_sync(D, A, B, C);
107  }
108 
109 #else
110  static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM75 and beyond");
111 #endif
112 
113 };
114 
116 //
117 // WMMA template structure defines nvcuda::wmma::fragments and static assert for
118 // wmma native instruction sizes supported for cutlass::uint1b_t (experimental::b1)
119 // (nvcuda::wmma targetting SASS instruction BMMA)
120 //
122 template <
123 typename Shape_,
124 typename LayoutA_,
125 typename LayoutB_,
126 typename LayoutC_>
127 struct Wmma<
128  Shape_,
129  cutlass::uint1b_t,
130  LayoutA_,
132  LayoutB_,
133  int32_t,
134  LayoutC_,
135  cutlass::arch::OpXorPopc
136 > {
137 #if defined(CUTLASS_ARCH_WMMA_SM75_ENABLED)
138  using Shape = Shape_;
139  using ElementA = cutlass::uint1b_t;
140  using LayoutA = LayoutA_;
141  using ElementB = cutlass::uint1b_t;
142  using LayoutB = LayoutB_;
143  using ElementC = int32_t;
144  using LayoutC = LayoutC_;
145  using Operator = cutlass::arch::OpXorPopc;
146 
147  // check supported wmma shape for the given multiplicand data types
150  "Supported list of wmma operator shape for b1 multiplicands is: 8x8x128");
151 
152 
153  // Wmma Fragment
154  using FragmentA = nvcuda::wmma::fragment<
155  nvcuda::wmma::matrix_a,
156  Shape::kM,
157  Shape::kN,
158  Shape::kK,
159  typename CutlassToWmmaDataType<ElementA>::Type,
160  typename CutlassToWmmaLayout<LayoutA>::Layout>;
161 
162  using FragmentB = nvcuda::wmma::fragment<
163  nvcuda::wmma::matrix_b,
164  Shape::kM,
165  Shape::kN,
166  Shape::kK,
167  typename CutlassToWmmaDataType<ElementB>::Type,
168  typename CutlassToWmmaLayout<LayoutB>::Layout>;
169 
170  using FragmentC = nvcuda::wmma::fragment<
171  nvcuda::wmma::accumulator,
172  Shape::kM,
173  Shape::kN,
174  Shape::kK,
175  typename CutlassToWmmaDataType<ElementC>::Type>;
176 
178  CUTLASS_DEVICE
179  void operator()(
180  FragmentC &D,
181  FragmentA const &A,
182  FragmentB const &B,
183  FragmentC const &C) const {
184 
185  nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR,
186  nvcuda::wmma::experimental::bmmaAccumulateOpPOPC);
187  }
188 
189 #else
190  static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM75 and beyond");
191 #endif
192 
193 };
194 
195 } // namespace arch
196 } // namespace cutlass
Definition: aligned_buffer.h:35
+Go to the documentation of this file.
1 /***************************************************************************************************
2  * Copyright (c) 2017-2019, NVIDIA CORPORATION. All rights reserved.
3  *
4  * Redistribution and use in source and binary forms, with or without modification, are permitted
5  * provided that the following conditions are met:
6  * * Redistributions of source code must retain the above copyright notice, this list of
7  * conditions and the following disclaimer.
8  * * Redistributions in binary form must reproduce the above copyright notice, this list of
9  * conditions and the following disclaimer in the documentation and/or other materials
10  * provided with the distribution.
11  * * Neither the name of the NVIDIA CORPORATION nor the names of its contributors may be used
12  * to endorse or promote products derived from this software without specific prior written
13  * permission.
14  *
15  * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR
16  * IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND
17  * FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL NVIDIA CORPORATION BE LIABLE
18  * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
19  * BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS;
20  * OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT,
21  * STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
22  * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
23  *
24  **************************************************************************************************/
29 #pragma once
30 
31 #include <assert.h>
32 #include "cutlass/layout/matrix.h"
33 
35 namespace cutlass {
36 namespace arch {
37 
39 //
40 // WMMA template structure defines nvcuda::wmma::fragments and static assert for
41 // wmma native instruction sizes supported for cutlass::int4b_t (experimental::s4).
42 //
44 template <
45 typename Shape_,
46 typename LayoutA_,
47 typename LayoutB_,
48 typename LayoutC_>
49 struct Wmma<
50  Shape_,
51  cutlass::int4b_t,
52  LayoutA_,
54  LayoutB_,
55  int32_t,
56  LayoutC_,
57  cutlass::arch::OpMultiplyAdd
58 > {
59 #if defined(CUTLASS_ARCH_WMMA_SM75_ENABLED)
60  using Shape = Shape_;
61  using ElementA = cutlass::int4b_t;
62  using LayoutA = LayoutA_;
63  using ElementB = cutlass::int4b_t;
64  using LayoutB = LayoutB_;
65  using ElementC = int32_t;
66  using LayoutC = LayoutC_;
67  using Operator = cutlass::arch::OpMultiplyAdd;
68 
69  // check supported wmma shape for the given multiplicand data types
72  "Supported list of wmma operator shape for s8 multiplicands is: 8x8x32");
73 
74 
75  // Wmma Fragment
76  using FragmentA = nvcuda::wmma::fragment<
77  nvcuda::wmma::matrix_a,
78  Shape::kM,
79  Shape::kN,
80  Shape::kK,
81  typename CutlassToWmmaDataType<ElementA>::Type,
82  typename CutlassToWmmaLayout<LayoutA>::Layout>;
83 
84  using FragmentB = nvcuda::wmma::fragment<
85  nvcuda::wmma::matrix_b,
86  Shape::kM,
87  Shape::kN,
88  Shape::kK,
89  typename CutlassToWmmaDataType<ElementB>::Type,
90  typename CutlassToWmmaLayout<LayoutB>::Layout>;
91 
92  using FragmentC = nvcuda::wmma::fragment<
93  nvcuda::wmma::accumulator,
94  Shape::kM,
95  Shape::kN,
96  Shape::kK,
97  typename CutlassToWmmaDataType<ElementC>::Type>;
98 
100  CUTLASS_DEVICE
101  void operator()(
102  FragmentC &D,
103  FragmentA const &A,
104  FragmentB const &B,
105  FragmentC const &C) const {
106  nvcuda::wmma::mma_sync(D, A, B, C);
107  }
108 
109 #else
110  static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM75 and beyond");
111 #endif
112 
113 };
114 
116 //
117 // WMMA template structure defines nvcuda::wmma::fragments and static assert for
118 // wmma native instruction sizes supported for cutlass::uint1b_t (experimental::b1)
119 // (nvcuda::wmma targeting SASS instruction BMMA)
120 //
122 template <
123 typename Shape_,
124 typename LayoutA_,
125 typename LayoutB_,
126 typename LayoutC_>
127 struct Wmma<
128  Shape_,
129  cutlass::uint1b_t,
130  LayoutA_,
132  LayoutB_,
133  int32_t,
134  LayoutC_,
135  cutlass::arch::OpXorPopc
136 > {
137 #if defined(CUTLASS_ARCH_WMMA_SM75_ENABLED)
138  using Shape = Shape_;
139  using ElementA = cutlass::uint1b_t;
140  using LayoutA = LayoutA_;
141  using ElementB = cutlass::uint1b_t;
142  using LayoutB = LayoutB_;
143  using ElementC = int32_t;
144  using LayoutC = LayoutC_;
145  using Operator = cutlass::arch::OpXorPopc;
146 
147  // check supported wmma shape for the given multiplicand data types
150  "Supported list of wmma operator shape for b1 multiplicands is: 8x8x128");
151 
152 
153  // Wmma Fragment
154  using FragmentA = nvcuda::wmma::fragment<
155  nvcuda::wmma::matrix_a,
156  Shape::kM,
157  Shape::kN,
158  Shape::kK,
159  typename CutlassToWmmaDataType<ElementA>::Type,
160  typename CutlassToWmmaLayout<LayoutA>::Layout>;
161 
162  using FragmentB = nvcuda::wmma::fragment<
163  nvcuda::wmma::matrix_b,
164  Shape::kM,
165  Shape::kN,
166  Shape::kK,
167  typename CutlassToWmmaDataType<ElementB>::Type,
168  typename CutlassToWmmaLayout<LayoutB>::Layout>;
169 
170  using FragmentC = nvcuda::wmma::fragment<
171  nvcuda::wmma::accumulator,
172  Shape::kM,
173  Shape::kN,
174  Shape::kK,
175  typename CutlassToWmmaDataType<ElementC>::Type>;
176 
178  CUTLASS_DEVICE
179  void operator()(
180  FragmentC &D,
181  FragmentA const &A,
182  FragmentB const &B,
183  FragmentC const &C) const {
184 
185  nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR,
186  nvcuda::wmma::experimental::bmmaAccumulateOpPOPC);
187  }
188 
189 #else
190  static_assert(false, "wmma.mma.sync integer type multiplicands is available only for SM75 and beyond");
191 #endif
192 
193 };
194 
195 } // namespace arch
196 } // namespace cutlass
Definition: aligned_buffer.h:35
std::is_same (false specialization)
Definition: platform.h:394
integer_subbyte< 1, false > uint1b_t
1-bit Unsigned integer type
Definition: integer_subbyte.h:152
4-bit signed integer type
Definition: integer_subbyte.h:42
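The listing above defines WMMA atoms for the sub-byte types; the b1 specialization (8x8x128, XOR plus POPC) is the one binary tensor core kernels build on. A minimal device-side sketch follows. It assumes cutlass/arch/wmma.h is the umbrella header that pulls this file in, that the primary cutlass::arch::Wmma template takes <Shape, ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, Operator> as reconstructed from the partial specialization above, and that the translation unit is compiled for SM75 or newer so CUTLASS_ARCH_WMMA_SM75_ENABLED is defined.

```cpp
#include "cutlass/arch/mma.h"        // cutlass::arch::OpXorPopc
#include "cutlass/arch/wmma.h"       // dispatches to wmma_sm75.h on SM75 and newer
#include "cutlass/gemm/gemm.h"       // cutlass::gemm::GemmShape
#include "cutlass/layout/matrix.h"
#include "cutlass/numeric_types.h"   // cutlass::uint1b_t

// Binary tensor op atom documented above: D = popc(A xor B) accumulated into int32.
using WmmaB1 = cutlass::arch::Wmma<
    cutlass::gemm::GemmShape<8, 8, 128>,   // the only shape accepted for b1 multiplicands
    cutlass::uint1b_t, cutlass::layout::RowMajor,
    cutlass::uint1b_t, cutlass::layout::ColumnMajor,
    int32_t,           cutlass::layout::RowMajor,
    cutlass::arch::OpXorPopc>;

__device__ void bmma_tile(WmmaB1::FragmentC &d,
                          WmmaB1::FragmentA const &a,
                          WmmaB1::FragmentB const &b,
                          WmmaB1::FragmentC const &c) {
  // Issues nvcuda::wmma::bmma_sync with bmmaBitOpXOR / bmmaAccumulateOpPOPC,
  // exactly as the specialization listed above does.
  WmmaB1 op;
  op(d, a, b, c);
}
```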
diff --git a/examples/00_basic_gemm/CMakeLists.txt b/examples/00_basic_gemm/CMakeLists.txt index 5af8fcf363..9002aad943 100644 --- a/examples/00_basic_gemm/CMakeLists.txt +++ b/examples/00_basic_gemm/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without diff --git a/examples/00_basic_gemm/basic_gemm.cu b/examples/00_basic_gemm/basic_gemm.cu index 7c633b30a5..c867112fc2 100644 --- a/examples/00_basic_gemm/basic_gemm.cu +++ b/examples/00_basic_gemm/basic_gemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -47,7 +47,7 @@ or utilities within CUTLASS. Such utilities are demonstrated elsewhere in other examples and are prevalent in the CUTLASS unit tests. - This example has delibrately been kept similar to the basic_gemm example from cutass-1.3 to + This example has delibrately been kept similar to the basic_gemm example from cutlass-1.3 to highlight the minimum amount of differences needed to transition to cutlass-2.0. Cutlass-1.3 sgemm: https://github.com/NVIDIA/cutlass/blob/master/examples/00_basic_gemm/basic_gemm.cu diff --git a/examples/01_cutlass_utilities/CMakeLists.txt b/examples/01_cutlass_utilities/CMakeLists.txt index 5673303843..bf37d18a84 100644 --- a/examples/01_cutlass_utilities/CMakeLists.txt +++ b/examples/01_cutlass_utilities/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without diff --git a/examples/01_cutlass_utilities/cutlass_utilities.cu b/examples/01_cutlass_utilities/cutlass_utilities.cu index 30ec28412d..43a3d46d99 100644 --- a/examples/01_cutlass_utilities/cutlass_utilities.cu +++ b/examples/01_cutlass_utilities/cutlass_utilities.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/02_dump_reg_shmem/CMakeLists.txt b/examples/02_dump_reg_shmem/CMakeLists.txt index 9dd94ab4a5..0216f2b480 100644 --- a/examples/02_dump_reg_shmem/CMakeLists.txt +++ b/examples/02_dump_reg_shmem/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -31,4 +31,5 @@ cutlass_example_add_executable( 02_dump_reg_shmem dump_reg_shmem.cu + DISABLE_TESTS ON ) diff --git a/examples/02_dump_reg_shmem/dump_reg_shmem.cu b/examples/02_dump_reg_shmem/dump_reg_shmem.cu index 159b0b4924..3db7821ffa 100644 --- a/examples/02_dump_reg_shmem/dump_reg_shmem.cu +++ b/examples/02_dump_reg_shmem/dump_reg_shmem.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/03_visualize_layout/CMakeLists.txt b/examples/03_visualize_layout/CMakeLists.txt index 27c38249ec..be8c7436fa 100644 --- a/examples/03_visualize_layout/CMakeLists.txt +++ b/examples/03_visualize_layout/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -29,7 +29,6 @@ set(TEST_COMMAND_00 RowMajor --extent=16,16) -set(TEST_COMMAND_01 \"ColumnMajorInterleaved<4>\" --extent=32,8 --output-shape=16 --vectorize=4) cutlass_example_add_executable( 03_visualize_layout @@ -37,6 +36,5 @@ cutlass_example_add_executable( register_layout.cu TEST_COMMAND_OPTIONS TEST_COMMAND_00 - TEST_COMMAND_01 ) diff --git a/examples/03_visualize_layout/options.h b/examples/03_visualize_layout/options.h index 2b1d8fdb5f..d422466852 100644 --- a/examples/03_visualize_layout/options.h +++ b/examples/03_visualize_layout/options.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/03_visualize_layout/register_layout.cu b/examples/03_visualize_layout/register_layout.cu index 060abe353b..d20c893a16 100644 --- a/examples/03_visualize_layout/register_layout.cu +++ b/examples/03_visualize_layout/register_layout.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -64,15 +64,15 @@ void RegisterLayouts(std::map // All Ampere/Turing H/Integer matrix multiply tensor core kernels uses the same swizzling // layout implementation with different templates. 
// - // BMMA 88128 Interleaved-256 - // BMMA 168256 Interleaved-256 + // mma.sync.aligned.m8n8k128.s32.b1.b1.s32 Interleaved-256 + // mma.sync.aligned.m16n8k256.s32.b1.b1.s32 Interleaved-256 {"TensorOpMultiplicand<1,256>", new VisualizeLayout>}, - // BMMA 88128 TN kblock512 - // BMMA 168256 TN kblock512 + // mma.sync.aligned.m8n8k128.s32.b1.b1.s32 TN kblock512 + // mma.sync.aligned.m16n8k256.s32.b1.b1.s32 TN kblock512 {"TensorOpMultiplicand<1,512>", new VisualizeLayout>}, - // BMMA 168256 TN kblock1024 + // mma.sync.aligned.m16n8k256.s32.b1.b1.s32 TN kblock1024 {"TensorOpMultiplicand<1,1024>", new VisualizeLayout>}, // Integer matrix multiply.int4 8832 Interleaved-64 diff --git a/examples/03_visualize_layout/register_layout.h b/examples/03_visualize_layout/register_layout.h index b473279af9..0375f3251c 100644 --- a/examples/03_visualize_layout/register_layout.h +++ b/examples/03_visualize_layout/register_layout.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/03_visualize_layout/visualize_layout.cpp b/examples/03_visualize_layout/visualize_layout.cpp index 81be32901d..1edf830d8e 100644 --- a/examples/03_visualize_layout/visualize_layout.cpp +++ b/examples/03_visualize_layout/visualize_layout.cpp @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -95,7 +95,7 @@ void print_usage(std::ostream &out) { "--extent=16,16 --vectorize=2 --output-shape=16,4\n" << "$ 03_visualize_layout \"VoltaTensorOpMultiplicandCrosswise<16,32>\" " "--extent=32,64 --vectorize=4 --output-shape=64,4\n" - << "$ 03_visualize_layout \"VotlaTensorOpMultiplicandCongruous<16>\" " + << "$ 03_visualize_layout \"VoltaTensorOpMultiplicandCongruous<16>\" " "--extent=64,32 --vectorize=8 --output-shape=64,4\n"; out << std::endl; diff --git a/examples/03_visualize_layout/visualize_layout.h b/examples/03_visualize_layout/visualize_layout.h index 7c9a1bc2f5..f070bad265 100644 --- a/examples/03_visualize_layout/visualize_layout.h +++ b/examples/03_visualize_layout/visualize_layout.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -260,7 +260,7 @@ class VisualizeLayout : public VisualizeLayoutBase { if (options.vectorize <= 2) return std::make_pair(false, -1); // Boundary check. - if (i > elements.size() || (i + options.vectorize - 1) > elements.size()) + if (i > int(elements.size()) || (i + options.vectorize - 1) > int(elements.size())) return std::make_pair(false, -1); // Check if either all elements are valid or invalid. 
diff --git a/examples/04_tile_iterator/CMakeLists.txt b/examples/04_tile_iterator/CMakeLists.txt index 52e1d665b6..55482729bd 100644 --- a/examples/04_tile_iterator/CMakeLists.txt +++ b/examples/04_tile_iterator/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without diff --git a/examples/04_tile_iterator/tile_iterator.cu b/examples/04_tile_iterator/tile_iterator.cu index 886c17701a..b9441a562d 100644 --- a/examples/04_tile_iterator/tile_iterator.cu +++ b/examples/04_tile_iterator/tile_iterator.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -94,7 +94,7 @@ __global__ void copy( typename Iterator::Fragment fragment; - for(int i = 0; i < fragment.size(); ++i) { + for(size_t i = 0; i < fragment.size(); ++i) { fragment[i] = 0; } diff --git a/examples/05_batched_gemm/CMakeLists.txt b/examples/05_batched_gemm/CMakeLists.txt index f42e76b235..cd69403aa9 100644 --- a/examples/05_batched_gemm/CMakeLists.txt +++ b/examples/05_batched_gemm/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without diff --git a/examples/05_batched_gemm/batched_gemm.cu b/examples/05_batched_gemm/batched_gemm.cu index 2ce552c79f..5fb7518f61 100644 --- a/examples/05_batched_gemm/batched_gemm.cu +++ b/examples/05_batched_gemm/batched_gemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -81,7 +81,7 @@ matrix A can be seen as --------------------------------------- batch 0 | batch 1 , where batch size is 2, M is 6 and K is 2 -The stride (batch_stride_B) between the first element of two batches is lda * k +The stride (batch_stride_A) between the first element of two batches is lda * k matrix B can be seen as ----------------------------- @@ -94,7 +94,7 @@ matrix B can be seen as (1,1,0) | (1,1,1) | (1,1,2) | ----------------------------- , where the batch size is 2, N is 3 and K is 2 -The stride (batch_stride_C) between the first element of two batches is k +The stride (batch_stride_B) between the first element of two batches is k */ @@ -207,15 +207,15 @@ cudaError_t strided_batched_gemm_nn_reference( cudaError_t result = cudaSuccess; - if (A.size() < lda * k * batch_count) { + if (A.size() < size_t(lda * k * batch_count)) { std::cout << "the size of A is too small" << std::endl; return cudaErrorInvalidValue; } - if (B.size() < ldb * n) { + if (B.size() < size_t(ldb * n)) { std::cout << "the size of B is too small" << std::endl; return cudaErrorInvalidValue; } - if (C.size() < ldc * n * batch_count) { + if (C.size() < size_t(ldc * n * batch_count)) { std::cout << "the size of C is too small" << std::endl; return cudaErrorInvalidValue; } diff --git a/examples/06_splitK_gemm/CMakeLists.txt b/examples/06_splitK_gemm/CMakeLists.txt index 04d7af8cf4..e0d11d0c0c 100644 --- a/examples/06_splitK_gemm/CMakeLists.txt +++ b/examples/06_splitK_gemm/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without diff --git a/examples/06_splitK_gemm/splitk_gemm.cu b/examples/06_splitK_gemm/splitk_gemm.cu index 03e60788b1..1a559b8326 100644 --- a/examples/06_splitK_gemm/splitk_gemm.cu +++ b/examples/06_splitK_gemm/splitk_gemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -55,7 +55,7 @@ composed from lower level ones. Multiple thread-tiles (tile size each thread com to form warp-tiles (tile size each warp computes) and multiple warp tiles can be used to compute threadblock-tile (tile size computed by a threadblock). -In thie example, we split variable initialization into +In this example, we split variable initialization into 1. Setting up data properties : describes how matrices are laid out in the memory and how the kernel can view them (logical to physical mapping) 2. Setting up computation properties : describes how the above set matrices will be used to compute @@ -74,10 +74,10 @@ ElementAccumulator (float), ElementComputeEpilogue (float), ElementInputA (cutla ElementInputB (cutlass::half_t), ElementOutput (float). Communicating just the data type is not enough. As the data is laid out linearly in memory, we have to convey the layout of matrices. 
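The stride rules spelled out for example 05 above (batch_stride_A = lda * k for the stacked column-major A, batch_stride_B = k for the interleaved column-major B, and, per the size checks, ldc * n between batches of C) reduce to plain pointer arithmetic. Here is a small host-side sketch of that addressing for the NN case; the function and variable names are illustrative, and it mirrors the reference check in the example rather than the CUTLASS batched kernel itself.

```cpp
#include <cstdio>
#include <vector>

// Reference strided-batched GEMM (NN, column-major) using the stride rules above.
void strided_batched_gemm_nn_ref(
    int m, int n, int k, int batch_count,
    std::vector<float> const &A, int lda, long long batch_stride_A,   // = lda * k
    std::vector<float> const &B, int ldb, long long batch_stride_B,   // = k
    std::vector<float> &C,       int ldc, long long batch_stride_C) { // = ldc * n
  for (int b = 0; b < batch_count; ++b) {
    float const *A_b = A.data() + b * batch_stride_A;
    float const *B_b = B.data() + b * batch_stride_B;
    float       *C_b = C.data() + b * batch_stride_C;
    for (int col = 0; col < n; ++col) {
      for (int row = 0; row < m; ++row) {
        float acc = 0.f;
        for (int kk = 0; kk < k; ++kk) {
          // Column-major: element (r, c) of a matrix with leading dimension ld is at r + c * ld.
          acc += A_b[row + kk * lda] * B_b[kk + col * ldb];
        }
        C_b[row + col * ldc] = acc;
      }
    }
  }
}

int main() {
  int m = 6, n = 3, k = 2, batch_count = 2;       // the sizes used in the pictures above
  int lda = m, ldb = k * batch_count, ldc = m;
  std::vector<float> A(lda * k * batch_count, 1.0f);
  std::vector<float> B(ldb * n, 1.0f);
  std::vector<float> C(ldc * n * batch_count, 0.0f);
  strided_batched_gemm_nn_ref(m, n, k, batch_count,
                              A, lda, 1LL * lda * k,
                              B, ldb, k,
                              C, ldc, 1LL * ldc * n);
  std::printf("C[0] = %f\n", C[0]);  // expect k = 2 for all-ones inputs
  return 0;
}
```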
We do that by initializing template variable LayoutInputA to column major cutlass variable, LayoutInputB -to row major and LayoutOutput to row major. Next, we setup rules to comptue alpha * X + beta * C +to row major and LayoutOutput to row major. Next, we setup rules to compute alpha * X + beta * C which is called epilogue of the kernel. We initialize template variable EpilogueOp, which takes the -data type of output ElementOutput (int32_t), the number of elements per vector memory access (16), -data type of accumulator (int32_t) and data type of computation of linear combination (alpha * X + +data type of output ElementOutput (float), the number of elements per vector memory access (16), +data type of accumulator (float) and data type of computation of linear combination (alpha * X + beta * C). Now that we setup the properties of data, we have to setup properties of computation. @@ -85,7 +85,7 @@ Now that we setup the properties of data, we have to setup properties of computa Second, we create template variables of tile sizes for thread-block, warp and mma-op to 128x128x32, 64x64x4, 8x8x4 (MxNxK) respectively. When passed to instantiate CUTLASS GEMM kernel, it internally deduce the amount of threads needed per thread-block, amount of shared memory, storing data in -bank-conflict free manner, and ton of other variables required to compose, intialize and launch a +bank-conflict free manner, and ton of other variables required to compose, initialize and launch a high performance GEMM kernel. This is the beauty of CUTLASS, it relieves developer from understanding and coding complicated hardware optimizations which can easily go wrong. @@ -95,7 +95,7 @@ is done which threadblock launched on an SM, CUDA SM architecture of GPU you wan These are all put together to create a template variable which describes CUTLASS GEMM kernel using cutlass::gemm::device::GemmSplitKParallel template. -The next step is to intialize physical data, instantiate and initialize CUTLASS kernel and run it. +The next step is to initialize physical data, instantiate and initialize CUTLASS kernel and run it. We use CUTLASS utilities to initialize, fill, compare matrices as they are simple and doesn't come in the way of learning CUTLASS. @@ -103,7 +103,7 @@ Once all the matrices are initialized and filled with data, create arguments tup kernel which takes problem size (M = 5120, N = 4096 and K = 4096), matrices, alpha, beta and the important one, split k-dimension factor. Along with that, we query CUTLASS if any scratch-space memory required by the kernel we instantiated. If yes, we create it and pass it along with other -arguments created to intialize CUTLASS kernel then, the kernel is launched. +arguments created to initialize CUTLASS kernel then, the kernel is launched. In this example, we later on launch a reference gemm kernel (from CUTLASS utilities) to compare if the output from CUTLASS kernel is same as reference GEMM kernel. @@ -149,9 +149,6 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M = // This code section describes the size of MMA op using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 4>; // <- MMA Op tile M = 8, N = 8, K = 4 -// This code section describes how threadblocks are scheduled on GPU -using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? - // This code section describes ? 
using EpilogueOp = cutlass::epilogue::thread::LinearCombination< ElementOutput, // <- data type of output matrix diff --git a/examples/07_volta_tensorop_gemm/CMakeLists.txt b/examples/07_volta_tensorop_gemm/CMakeLists.txt index c53367ac5a..2503cd3d43 100644 --- a/examples/07_volta_tensorop_gemm/CMakeLists.txt +++ b/examples/07_volta_tensorop_gemm/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without diff --git a/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu b/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu index eecd86cd45..23c2d9f45f 100644 --- a/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu +++ b/examples/07_volta_tensorop_gemm/volta_tensorop_gemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -75,7 +75,7 @@ Now that we setup the properties of data, we have to setup properties of computa Second, we create template variables of tile sizes for thread-block, warp and mma-op to 128x128x32, 64x64x32, 8x8x4 (MxNxK) respectively. When passed to instantiate CUTLASS GEMM kernel, it internally deduce the amount of threads needed per thread-block, amount of shared memory, storing data in -bank-conflict free manner, and ton of other variables required to compose, intialize and launch a +bank-conflict free manner, and ton of other variables required to compose, initialize and launch a high performance GEMM kernel. This is the beauty of CUTLASS, it relieves developer from understanding and coding complicated hardware optimizations which can easily go wrong. @@ -107,7 +107,7 @@ is done which threadblock launched on an SM, CUDA SM architecture of GPU you wan These are all put together to create a template variable which describes CUTLASS GEMM kernel using cutlass::gemm::device::Gemm template. -The next step is to intialize physical data, instantiate and initialize CUTLASS kernel and run it. +The next step is to initialize physical data, instantiate and initialize CUTLASS kernel and run it. We use CUTLASS utilities to initialize, fill, compare matrices as they are simple and doesn't come in the way of learning CUTLASS. @@ -115,7 +115,7 @@ Once all the matrices are initialized and filled with data, create arguments tup kernel which takes problem size (M = 5120, N = 4096 and K = 4096), matrices, alpha, beta and the important one, split k-dimension factor. Along with that, we query CUTLASS if any scratch-space memory required by the kernel we instantiated. If yes, we create it and pass it along with other -arguments created to intialize CUTLASS kernel then, the kernel is launched. +arguments created to initialize CUTLASS kernel then, the kernel is launched. In this example, we later on launch a reference gemm kernel (from CUTLASS utilities) to compare if the output from CUTLASS kernel is same as reference GEMM kernel. 
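The walkthroughs corrected in examples 06, 07 and 08 above all follow the same CUTLASS 2.x recipe: choose element types and layouts, choose threadblock, warp and instruction tile shapes, choose an epilogue, compose a device-level GEMM, then query workspace, initialize and launch. The sketch below condenses that recipe in the style of example 07. It is a sketch under assumptions rather than a drop-in replacement: it presumes an SM70-class GPU, the header paths named in the includes, and the tile sizes quoted in the comments above, and it trims error handling to a single status check.

```cpp
#include <cstdint>
#include <iostream>

#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"
#include "cutlass/gemm/device/gemm.h"
#include "cutlass/util/device_memory.h"
#include "cutlass/util/host_tensor.h"
#include "cutlass/util/reference/host/tensor_fill.h"

// Element types: f16 inputs, f32 accumulation, f32 output and epilogue compute.
using ElementInputA = cutlass::half_t;
using ElementInputB = cutlass::half_t;
using ElementOutput = float;
using ElementAccumulator = float;
using ElementComputeEpilogue = ElementAccumulator;

// Layouts: column-major A, row-major B and output.
using LayoutInputA = cutlass::layout::ColumnMajor;
using LayoutInputB = cutlass::layout::RowMajor;
using LayoutOutput = cutlass::layout::RowMajor;

// Tensor cores on Volta with the tile shapes quoted in the comments above.
using MMAOp  = cutlass::arch::OpClassTensorOp;
using SmArch = cutlass::arch::Sm70;
using ShapeMMAThreadBlock = cutlass::gemm::GemmShape<128, 128, 32>;
using ShapeMMAWarp        = cutlass::gemm::GemmShape<64, 64, 32>;
using ShapeMMAOp          = cutlass::gemm::GemmShape<8, 8, 4>;

// Epilogue: D = alpha * (A x B) + beta * C, vectorized as 128-bit accesses.
using EpilogueOp = cutlass::epilogue::thread::LinearCombination<
    ElementOutput, 128 / cutlass::sizeof_bits<ElementOutput>::value,
    ElementAccumulator, ElementComputeEpilogue>;

using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>;
constexpr int NumStages = 2;

using Gemm = cutlass::gemm::device::Gemm<
    ElementInputA, LayoutInputA, ElementInputB, LayoutInputB,
    ElementOutput, LayoutOutput, ElementAccumulator, MMAOp, SmArch,
    ShapeMMAThreadBlock, ShapeMMAWarp, ShapeMMAOp,
    EpilogueOp, SwizzleThreadBlock, NumStages>;

int main() {
  cutlass::gemm::GemmCoord problem_size(5120, 4096, 4096);

  cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(problem_size.mk());
  cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(problem_size.kn());
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(problem_size.mn());
  cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(problem_size.mn());

  // Trivial fills keep the sketch short; the examples use random distributions instead.
  cutlass::reference::host::TensorFill(tensor_a.host_view(), ElementInputA(1));
  cutlass::reference::host::TensorFill(tensor_b.host_view(), ElementInputB(1));
  cutlass::reference::host::TensorFill(tensor_c.host_view());
  tensor_a.sync_device();
  tensor_b.sync_device();
  tensor_c.sync_device();
  tensor_d.sync_device();

  int split_k_slices = 1;
  Gemm::Arguments arguments{problem_size,
                            tensor_a.device_ref(), tensor_b.device_ref(),
                            tensor_c.device_ref(), tensor_d.device_ref(),
                            {ElementComputeEpilogue(1), ElementComputeEpilogue(0)},
                            split_k_slices};

  // Query scratch space, allocate it, then initialize and launch the kernel.
  size_t workspace_size = Gemm::get_workspace_size(arguments);
  cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

  Gemm gemm_op;
  cutlass::Status status = gemm_op.can_implement(arguments);
  if (status == cutlass::Status::kSuccess) {
    status = gemm_op.initialize(arguments, workspace.get());
  }
  if (status == cutlass::Status::kSuccess) {
    status = gemm_op();
  }

  std::cout << (status == cutlass::Status::kSuccess ? "Passed" : "Failed") << std::endl;
  return status == cutlass::Status::kSuccess ? 0 : -1;
}
```

Broadly, example 06 builds its split-K variant the same way, composing cutlass::gemm::device::GemmSplitKParallel instead and passing a split_k_slices value greater than one.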
@@ -162,7 +162,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M = using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 4>; // <- MMA Op tile M = 8, N = 8, K = 4 // This code section describes how threadblocks are scheduled on GPU -using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? +using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // This code section describes ? using EpilogueOp = cutlass::epilogue::thread::LinearCombination< diff --git a/examples/08_turing_tensorop_gemm/CMakeLists.txt b/examples/08_turing_tensorop_gemm/CMakeLists.txt index c84bc6d49b..2e0a54817d 100644 --- a/examples/08_turing_tensorop_gemm/CMakeLists.txt +++ b/examples/08_turing_tensorop_gemm/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without diff --git a/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu b/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu index 1869c41f00..34f682deb0 100644 --- a/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu +++ b/examples/08_turing_tensorop_gemm/turing_tensorop_gemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -74,7 +74,7 @@ Now that we setup the properties of data, we have to setup properties of computa Second, we create template variables of tile sizes for thread-block, warp and mma-op to 128x256x64, 64x64x16, 8x8x16 (MxNxK) respectively. When passed to instantiate CUTLASS GEMM kernel, it internally deduce the amount of threads needed per thread-block, amount of shared memory, storing data in -bank-conflict free manner, and ton of other variables required to compose, intialize and launch a +bank-conflict free manner, and ton of other variables required to compose, initialize and launch a high performance GEMM kernel. This is the beauty of CUTLASS, it relieves developer from understanding and coding complicated hardware optimizations which can easily go wrong. @@ -106,7 +106,7 @@ is done which threadblock launched on an SM, CUDA SM architecture of GPU you wan These are all put together to create a template variable which describes CUTLASS GEMM kernel using cutlass::gemm::device::Gemm template. -The next step is to intialize physical data, instantiate and initialize CUTLASS kernel and run it. +The next step is to initialize physical data, instantiate and initialize CUTLASS kernel and run it. We use CUTLASS utilities to initialize, fill, compare matrices as they are simple and doesn't come in the way of learning CUTLASS. @@ -114,7 +114,7 @@ Once all the matrices are initialized and filled with data, create arguments tup kernel which takes problem size (M = 5120, N = 4096 and K = 4096), matrices, alpha, beta and the important one, split k-dimension factor. Along with that, we query CUTLASS if any scratch-space memory required by the kernel we instantiated. 
If yes, we create it and pass it along with other -arguments created to intialize CUTLASS kernel then, the kernel is launched. +arguments created to initialize CUTLASS kernel then, the kernel is launched. In this example, we later on launch a reference gemm kernel (from CUTLASS utilities) to compare if the output from CUTLASS kernel is same as reference GEMM kernel. @@ -140,8 +140,8 @@ using ElementInputA = int8_t; // <- data type of elements using ElementInputB = int8_t; // <- data type of elements in input matrix B using ElementOutput = int32_t; // <- data type of elements in output matrix D -// The code section below describes matrix layout of input and output matrices. Column Major for -// Matrix A, Row Major for Matrix B and Row Major for Matrix C +// The code section below describes matrix layout of input and output matrices. Row Major for +// Matrix A, Column Major for Matrix B and Row Major for Matrix C using LayoutInputA = cutlass::layout::RowMajor; using LayoutInputB = cutlass::layout::ColumnMajor; using LayoutOutput = cutlass::layout::RowMajor; @@ -161,7 +161,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 64>; // <- warp tile M = using ShapeMMAOp = cutlass::gemm::GemmShape<8, 8, 16>; // <- MMA Op tile M = 8, N = 8, K = 16 // This code section describes how threadblocks are scheduled on GPU -using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? +using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // This code section describes the epilogue part of the kernel using EpilogueOp = cutlass::epilogue::thread::LinearCombination< @@ -355,4 +355,3 @@ int main() { return run(); } - diff --git a/examples/09_turing_tensorop_conv2dfprop/CMakeLists.txt b/examples/09_turing_tensorop_conv2dfprop/CMakeLists.txt index 09057a28d5..673064edc4 100644 --- a/examples/09_turing_tensorop_conv2dfprop/CMakeLists.txt +++ b/examples/09_turing_tensorop_conv2dfprop/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without diff --git a/examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu b/examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu index bd74ce12da..adca0568bd 100644 --- a/examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu +++ b/examples/09_turing_tensorop_conv2dfprop/turing_tensorop_conv2dfprop.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -76,7 +76,7 @@ Now that we setup the properties of data, we have to setup properties of computa Second, we create template variables of tile sizes for thread-block, warp and mma-op to 128x128x128, 64x64x128, 8x8x32 (MxNxK) respectively. 
When passed to instantiate CUTLASS Implicit GEMM kernel, it internally deduces the amount of threads needed per thread-block, amount of shared memory, storing -data in bank-conflict free manner, and ton of other variables required to compose, intialize and +data in bank-conflict free manner, and ton of other variables required to compose, initialize and launch a high performance Implicit GEMM kernel. This is the beauty of CUTLASS, it relieves developer from understanding and coding complicated hardware optimizations which can easily go wrong. @@ -108,7 +108,7 @@ is done which threadblock launched on an SM, CUDA SM architecture of GPU you wan These are all put together to create a template variable which describes CUTLASS Implicit GEMM kernel using cutlass::conv::device::ImplicitGemm template. -The next step is to intialize physical data, instantiate and initialize CUTLASS kernel and run it. +The next step is to initialize physical data, instantiate and initialize CUTLASS kernel and run it. We use CUTLASS utilities to initialize, fill, compare tensors as they are simple and doesn't come in the way of learning CUTLASS. @@ -117,7 +117,7 @@ kernel which takes problem size (N = 1, H = 64, W = 64, C = 128), filter size (K R = 3, S = 3, C = 128 ), padding, strides, dilation, tensors, alpha, beta and the important one, split k-dimension factor. Along with that, we query CUTLASS if any scratch-space memory required by the kernel we instantiated. If yes, we create it and pass it along with other -arguments created to intialize CUTLASS kernel then, the kernel is launched. +arguments created to initialize CUTLASS kernel then, the kernel is launched. In this example, we later on launch a reference convolution kernel (from CUTLASS utilities) to compare if the output from CUTLASS kernel is same as the reference implicit GEMM kernel. @@ -143,7 +143,6 @@ compare if the output from CUTLASS kernel is same as the reference implicit GEMM #include "cutlass/util/tensor_view_io.h" #include "helper.h" - // The code section below describes datatype for input, output tensors and computation between // elements using ElementAccumulator = int32_t; // Data type of accumulator @@ -555,6 +554,7 @@ Result profile_convolution(Options const &options) { LayoutOutput, ElementComputeEpilogue, ElementAccumulator, + ElementOutput, cutlass::NumericConverterClamp >( problem_size, @@ -674,7 +674,6 @@ Result profile_convolution(Options const &options) { return result; } - ///////////////////////////////////////////////////////////////////////////////////////////////// int main(int argc, char const **args) { @@ -761,11 +760,7 @@ int main(int argc, char const **args) { Result::print_header(std::cout, options) << std::endl; result.print(std::cout, 1, options) << std::endl; } - return 0; } ///////////////////////////////////////////////////////////////////////////////////////////////// - - - diff --git a/examples/10_planar_complex/CMakeLists.txt b/examples/10_planar_complex/CMakeLists.txt index eaf41fde7c..ebe78d6b61 100644 --- a/examples/10_planar_complex/CMakeLists.txt +++ b/examples/10_planar_complex/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
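Example 09's walkthrough above maps a forward convolution onto an implicit GEMM. The bookkeeping behind that mapping is small enough to show directly. A minimal sketch using the problem size quoted above (N=1, H=64, W=64, C=128 activations and K=128, R=3, S=3 filters) with illustrative padding, stride and dilation values, assuming the usual cross-correlation output-extent formula:

```cpp
#include <cstdio>

int main() {
  // Activation N x H x W x C and filter K x R x S x C from example 09.
  int N = 1, H = 64, W = 64, C = 128;
  int K = 128, R = 3, S = 3;
  int pad_h = 1, pad_w = 1, stride_h = 1, stride_w = 1, dil_h = 1, dil_w = 1;

  // Output extents of the convolution.
  int P = (H + 2 * pad_h - dil_h * (R - 1) - 1) / stride_h + 1;
  int Q = (W + 2 * pad_w - dil_w * (S - 1) - 1) / stride_w + 1;

  // Forward propagation (fprop) maps onto a GEMM of these dimensions.
  long long gemm_m = 1LL * N * P * Q;
  long long gemm_n = K;
  long long gemm_k = 1LL * R * S * C;

  std::printf("output %dx%dx%dx%d, implicit GEMM %lldx%lldx%lld\n",
              N, P, Q, K, gemm_m, gemm_n, gemm_k);
  return 0;
}
```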
# SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -27,7 +27,10 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - +# +# This example depends on the CUTLASS Library +# +if (CUTLASS_ENABLE_LIBRARY) # Planar Complex GEMM example cutlass_example_add_executable( @@ -35,15 +38,12 @@ cutlass_example_add_executable( planar_complex.cu ) - -# -# This example depends on the CUTLASS Library -# - target_link_libraries( 10_planar_complex PRIVATE cutlass_lib cutlass_tools_util_includes + cuda ) +endif() diff --git a/examples/10_planar_complex/planar_complex.cu b/examples/10_planar_complex/planar_complex.cu index 9a9dc88888..2d7ee95eec 100644 --- a/examples/10_planar_complex/planar_complex.cu +++ b/examples/10_planar_complex/planar_complex.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/11_planar_complex_array/CMakeLists.txt b/examples/11_planar_complex_array/CMakeLists.txt index b5ad07cf7a..0e3fc9e987 100644 --- a/examples/11_planar_complex_array/CMakeLists.txt +++ b/examples/11_planar_complex_array/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -27,7 +27,10 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - +# +# This example depends on the CUTLASS Library +# +if (CUTLASS_ENABLE_LIBRARY) # Planar Complex Array GEMM example cutlass_example_add_executable( @@ -35,15 +38,12 @@ cutlass_example_add_executable( planar_complex_array.cu ) - -# -# This example depends on the CUTLASS Library -# - target_link_libraries( 11_planar_complex_array PRIVATE cutlass_lib cutlass_tools_util_includes + cuda ) +endif() diff --git a/examples/11_planar_complex_array/planar_complex_array.cu b/examples/11_planar_complex_array/planar_complex_array.cu index 272390f26b..0df6e57284 100644 --- a/examples/11_planar_complex_array/planar_complex_array.cu +++ b/examples/11_planar_complex_array/planar_complex_array.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/12_gemm_bias_relu/CMakeLists.txt b/examples/12_gemm_bias_relu/CMakeLists.txt index 3b681b31fa..e3e428dfcb 100644 --- a/examples/12_gemm_bias_relu/CMakeLists.txt +++ b/examples/12_gemm_bias_relu/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without diff --git a/examples/12_gemm_bias_relu/gemm_bias_relu.cu b/examples/12_gemm_bias_relu/gemm_bias_relu.cu index 62eb294028..bca8e0ac74 100644 --- a/examples/12_gemm_bias_relu/gemm_bias_relu.cu +++ b/examples/12_gemm_bias_relu/gemm_bias_relu.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -54,12 +54,11 @@ using ElementInputA = cutlass::half_t; // <- data type of elements using ElementInputB = cutlass::half_t; // <- data type of elements in input matrix B using ElementOutput = float; // <- data type of elements in output matrix D -// The code section below describes matrix layout of input and output matrices. -// Column Major for Matrix A, B and C. - // Note that if the output is column major, the bias has to be per row. i.e. every row has different bias. // If the output is row major, the bias has to be per column, i.e. every column has different bias. // Below list some other notices: +// +// Note this example only works for ColumnMajor output because // 1) we only have row major epilogue. // 2) we swap A and B if the output is column major then we can still use the // row major epilogue. @@ -82,10 +81,10 @@ using ShapeMMAThreadBlock = // This code section describes tile size a warp will compute using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 32>; // <- warp tile M = 64, N = 64, K = 32 // This code section describes the size of MMA op -using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 8, N = 8, K = 4 +using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8 // This code section describes how threadblocks are scheduled on GPU -using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? +using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // Define the epilogue operation as LinearCombinationRelu. This is approximately equal to // diff --git a/examples/13_two_tensor_op_fusion/CMakeLists.txt b/examples/13_two_tensor_op_fusion/CMakeLists.txt index 04d55bbec5..6819a9766e 100644 --- a/examples/13_two_tensor_op_fusion/CMakeLists.txt +++ b/examples/13_two_tensor_op_fusion/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -64,6 +64,7 @@ endforeach() foreach(FUSION_GEMM_EXAMPLE fused_two_gemms_f16_sm75_rf fused_two_gemms_f16_sm75_shmem + fused_two_gemms_grouped_f16_sm80_rf fused_two_gemms_f16_sm80_rf fused_two_gemms_f16_sm80_shmem fused_two_gemms_s8_sm75_rf @@ -79,4 +80,3 @@ foreach(FUSION_GEMM_EXAMPLE add_dependencies(13_fused_two_gemms 13_${FUSION_GEMM_EXAMPLE}) endforeach() - diff --git a/examples/13_two_tensor_op_fusion/README.md b/examples/13_two_tensor_op_fusion/README.md index 134644a063..4b9cb6d128 100644 --- a/examples/13_two_tensor_op_fusion/README.md +++ b/examples/13_two_tensor_op_fusion/README.md @@ -1,11 +1,11 @@ # Introduction -This example shows fusing two back-to-back GEMMs/Convolutions into one kernel. +This example shows fusing two back-to-back GEMMs/Convolutions into one kernel.

-When running two unfused GEMM/Conv operations, each operation loads one input -activation matrix, one weight matrix (or filter matrix) from the memory and then +When running two unfused GEMM/Conv operations, each operation loads one input +activation matrix, one weight matrix (or filter matrix) from the memory and then stores the result activation matrix back to the memory. When the two GEMM/Conv operations are fused together, the mainloops of the two @@ -27,10 +27,10 @@ In order to run two GEMM/Convs in a single kernel, the example requires the same threadblocks are used across 2 GEMMs/Convs. This also ensures the same threadblock tile M across 2 GEMMs/Convs. -In order to reuse the output accumulator (stored in register-file) of the 1st GEMM as the +In order to reuse the output accumulator (stored in register-file) of the 1st GEMM as the input activation, the example enforces the following two constraints: -- thread_block_tile_N = problem_N +- thread_block_tile_N = problem_N

@@ -39,7 +39,7 @@ addition to its own input activation tile. Therefore the input activation tile o 2nd GEMM/Conv only depends on the output activation tile of the 1st GEMM/Conv, and the operation can be fully block-resident. -- warp_tile_N = thread_block_tile_N +- warp_tile_N = thread_block_tile_N
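For concreteness, below is a minimal sketch of tile shapes that satisfy the two constraints stated above, assuming the first GEMM has N = 64 and the second has N = 128; the alias names mirror the `ThreadblockShape0/1` and `WarpShape0/1` template parameters this example uses, but the concrete sizes here are only illustrative.

```c++
#include "cutlass/gemm/gemm.h"

// GEMM0: the threadblock tile N must cover the full problem N of the first GEMM,
// so the entire output row strip of GEMM0 is produced inside one threadblock.
using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>;   // kN == problem_size_0.n()
// The warp tile N must equal the threadblock tile N, so each warp holds the full
// accumulator strip it will feed into the second GEMM (register-file resident fusion).
using WarpShape0        = cutlass::gemm::GemmShape<32, 64, 32>;   // kN == ThreadblockShape0::kN

// GEMM1: the same two rules applied to the second GEMM's N dimension.
using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>;  // kN == problem_size_1.n()
using WarpShape1        = cutlass::gemm::GemmShape<32, 128, 32>;  // kN == ThreadblockShape1::kN
```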

@@ -82,11 +82,11 @@ threadblock. Typically this requires the 2nd Convolution uses 1x1 filter without - `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm75_shmem` - `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm80_rf` - `./examples/13_two_tensor_op_fusion/13_fused_two_gemms_s8_sm80_shmem` - + # Copyright -Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. SPDX-License-Identifier: BSD-3-Clause ``` diff --git a/examples/13_two_tensor_op_fusion/b2b_conv2d_run.h b/examples/13_two_tensor_op_fusion/b2b_conv2d_run.h index b050906317..03ae75c62c 100644 --- a/examples/13_two_tensor_op_fusion/b2b_conv2d_run.h +++ b/examples/13_two_tensor_op_fusion/b2b_conv2d_run.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/13_two_tensor_op_fusion/b2b_gemm_run.h b/examples/13_two_tensor_op_fusion/b2b_gemm_run.h index 6cc4ffd9f4..8e828d1f17 100644 --- a/examples/13_two_tensor_op_fusion/b2b_gemm_run.h +++ b/examples/13_two_tensor_op_fusion/b2b_gemm_run.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -42,6 +42,7 @@ #include "cutlass/util/reference/host/tensor_compare.h" #include "cutlass/util/reference/host/tensor_norm.h" #include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/gemm_complex.h" #include "cutlass/util/reference/device/tensor_relu.h" #include "reference/device/tensor_scale_bias.h" @@ -77,9 +78,9 @@ struct B2bNonFusedGemmRun // B2bNonFusedGemmRun( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, uint64_t seed_ = 2080 ): @@ -88,7 +89,7 @@ struct B2bNonFusedGemmRun /// Helper to initialize a tensor view template bool initialize_tensor( - cutlass::TensorView view, + cutlass::TensorView view, cutlass::Distribution::Kind dist_kind, uint64_t seed) { @@ -96,7 +97,7 @@ struct B2bNonFusedGemmRun cutlass::reference::host::TensorFillRandomUniform( view, seed, 2, -2, 0); - } + } else if (dist_kind == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(view); @@ -129,62 +130,62 @@ struct B2bNonFusedGemmRun /// Executes one test bool run( - cutlass::gemm::GemmCoord problem_size_0, - cutlass::gemm::GemmCoord problem_size_1, - ElementCompute alpha0 = ElementCompute(1), + cutlass::gemm::GemmCoord problem_size_0, + cutlass::gemm::GemmCoord problem_size_1, + ElementCompute 
alpha0 = ElementCompute(1), ElementCompute beta0 = ElementCompute(0), - ElementCompute alpha1 = ElementCompute(1), + ElementCompute alpha1 = ElementCompute(1), ElementCompute beta1 = ElementCompute(0), bool relu = true, int warm_ups = 1, int runs = 100) { - + // // Allocate the GEMM workspace // cutlass::HostTensor< - typename Gemm0::ElementA, + typename Gemm0::ElementA, typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk()); cutlass::HostTensor< - typename Gemm0::ElementB, + typename Gemm0::ElementB, typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn()); cutlass::HostTensor< - typename Gemm0::ElementC, + typename Gemm0::ElementC, typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn()); cutlass::HostTensor< - ElementCompute, + ElementCompute, typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()}); cutlass::HostTensor< - typename Gemm0::ElementC, + typename Gemm0::ElementC, typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn()); cutlass::HostTensor< - typename Gemm0::ElementC, + typename Gemm0::ElementC, typename Gemm0::LayoutC> reference_D0(problem_size_0.mn()); cutlass::HostTensor< - typename Gemm1::ElementB, + typename Gemm1::ElementB, typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn()); cutlass::HostTensor< - typename Gemm1::ElementC, + typename Gemm1::ElementC, typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn()); cutlass::HostTensor< - ElementCompute, + ElementCompute, typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()}); cutlass::HostTensor< - typename Gemm1::ElementC, + typename Gemm1::ElementC, typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn()); cutlass::HostTensor< - typename Gemm1::ElementC, + typename Gemm1::ElementC, typename Gemm1::LayoutC> reference_D1(problem_size_1.mn()); @@ -270,13 +271,13 @@ struct B2bNonFusedGemmRun for(int i = 0; i < runs; i++) { status = gemm_op_0(); - + CUTLASS_CHECK(status); } cudaEventRecord(stop1); for(int i = 0; i < runs; i++) { status = gemm_op_1(); - + CUTLASS_CHECK(status); } @@ -312,32 +313,32 @@ struct B2bNonFusedGemmRun reference_gemm_0( problem_size_0, - alpha0, - tensor_A0.device_ref(), - tensor_B0.device_ref(), - beta0, + alpha0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + beta0, {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, reference_D0.device_ref() ); if(relu) { - cutlass::reference::device::TensorReLu(reference_D0.device_view()); + cutlass::reference::device::TensorReLu(reference_D0.device_view()); } reference_gemm_1( problem_size_1, - alpha1, - reference_D0.device_ref(), - tensor_B1.device_ref(), + alpha1, + reference_D0.device_ref(), + tensor_B1.device_ref(), beta1, {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, reference_D1.device_ref() ); - + if(relu) { - cutlass::reference::device::TensorReLu(reference_D1.device_view()); + cutlass::reference::device::TensorReLu(reference_D1.device_view()); } - + // Wait for kernels to finish cudaDeviceSynchronize(); reference_D0.sync_host(); @@ -349,7 +350,7 @@ struct B2bNonFusedGemmRun CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); bool passed = cutlass::reference::host::TensorEquals( - reference_D1.host_view(), + reference_D1.host_view(), tensor_D1.host_view()); CHECK_TRUE(passed); @@ -362,7 +363,7 @@ struct B2bNonFusedGemmRun std::ofstream file(fname.str()); - file + file << "A0 =\n" << tensor_A0.host_view() << "\nB0 =\n" << tensor_B0.host_view() << "\nC0 =\n" << tensor_C0.host_view() @@ -399,9 +400,9 @@ struct B2bFusedGemmRun // B2bFusedGemmRun( - cutlass::Distribution::Kind init_A_ 
= cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, uint64_t seed_ = 2080 @@ -412,7 +413,7 @@ struct B2bFusedGemmRun /// Helper to initialize a tensor view template bool initialize_tensor( - cutlass::TensorView view, + cutlass::TensorView view, cutlass::Distribution::Kind dist_kind, uint64_t seed) { @@ -420,11 +421,11 @@ struct B2bFusedGemmRun cutlass::reference::host::TensorFillRandomUniform( view, seed, 2, -2, 0); - } + } else if (dist_kind == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(view); - } + } else if (dist_kind == cutlass::Distribution::Gaussian) { cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); @@ -453,70 +454,90 @@ struct B2bFusedGemmRun /// Executes one test bool run( - cutlass::gemm::GemmCoord problem_size_0, - cutlass::gemm::GemmCoord problem_size_1, - ElementCompute alpha0 = ElementCompute(1), + cutlass::gemm::GemmCoord problem_size_0, + cutlass::gemm::GemmCoord problem_size_1, + ElementCompute alpha0 = ElementCompute(1), ElementCompute beta0 = ElementCompute(0), - ElementCompute alpha1 = ElementCompute(1), + ElementCompute alpha1 = ElementCompute(1), ElementCompute beta1 = ElementCompute(0), + cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm, + + // batch_count is used as split-k when mode is kGemm according + // to the GemmUniversal interface + + int batch_count = 1, + int64_t batch_stride_A0 = 0, + int64_t batch_stride_B0 = 0, + int64_t batch_stride_C0 = 0, + int64_t batch_stride_B1 = 0, + int64_t batch_stride_C1 = 0, + int64_t batch_stride_D1 = 0, + int64_t batch_stride_Bias0 = 0, + int64_t batch_stride_Scale0 = 0, bool relu = true, int warm_ups = 1, int runs = 100) { - + // // Allocate the GEMM workspace // + cutlass::gemm::GemmCoord CoordA0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k()); + cutlass::gemm::GemmCoord CoordB0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k()); + cutlass::gemm::GemmCoord CoordC0(problem_size_0.m(), batch_count * problem_size_0.n(), problem_size_0.k()); + cutlass::gemm::GemmCoord CoordB1(problem_size_1.m(), problem_size_1.n(), batch_count * problem_size_1.k()); + cutlass::gemm::GemmCoord CoordC1(problem_size_1.m(), batch_count * problem_size_1.n(), problem_size_1.k()); + cutlass::HostTensor< - typename B2bGemm::ElementA, - typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk()); + typename B2bGemm::ElementA, + typename B2bGemm::LayoutA> tensor_A0(CoordA0.mk()); cutlass::HostTensor< - typename B2bGemm::ElementB, - typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn()); + typename B2bGemm::ElementB, + typename B2bGemm::LayoutB> tensor_B0(CoordB0.kn()); cutlass::HostTensor< - typename B2bGemm::ElementC, - typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn()); + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> tensor_C0(CoordC0.mn()); cutlass::HostTensor< - typename B2bGemm::ElementScaleBias, + typename B2bGemm::ElementScaleBias, typename B2bGemm::LayoutScaleBias> tensor_Scale0; if(alpha0 == 
ElementCompute(0)) //per-channel scale - tensor_Scale0.resize({1, problem_size_0.n()}); + tensor_Scale0.resize({1, batch_count * problem_size_0.n()}); cutlass::HostTensor< - typename B2bGemm::ElementScaleBias, - typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()}); + typename B2bGemm::ElementScaleBias, + typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, batch_count * problem_size_0.n()}); cutlass::HostTensor< - ElementAccumulator, - typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn()); + ElementAccumulator, + typename B2bGemm::LayoutC> reference_Z0(CoordC0.mn()); cutlass::HostTensor< - typename B2bGemm::ElementC, - typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn()); + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> reference_D0(CoordC0.mn()); cutlass::HostTensor< - typename B2bGemm::ElementB, - typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn()); + typename B2bGemm::ElementB, + typename B2bGemm::LayoutB> tensor_B1(CoordB1.kn()); cutlass::HostTensor< - typename B2bGemm::ElementC, - typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn()); + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> tensor_C1(CoordC1.mn()); cutlass::HostTensor< - ElementCompute, - typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()}); + typename B2bGemm::ElementC, + typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, batch_count * problem_size_1.n()}); cutlass::HostTensor< - typename B2bGemm::ElementC, - typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn()); + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> tensor_D1(CoordC1.mn()); cutlass::HostTensor< - typename B2bGemm::ElementC, - typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn()); + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> reference_D1(CoordC1.mn()); CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); @@ -554,6 +575,7 @@ struct B2bFusedGemmRun // typename B2bGemm::Arguments arguments{ + mode, problem_size_0, problem_size_1, tensor_A0.device_ref(), @@ -564,8 +586,16 @@ struct B2bFusedGemmRun tensor_B1.device_ref(), {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, tensor_D1.device_ref(), + batch_stride_A0, + batch_stride_B0, + batch_stride_B1, + batch_stride_C1, + batch_stride_D1, + batch_stride_Bias0, + batch_stride_Scale0, {alpha0, beta0}, {alpha1, beta1}, + batch_count, }; B2bGemm b2b_gemm_op; @@ -618,32 +648,31 @@ struct B2bFusedGemmRun // Verify // - cutlass::reference::device::Gemm< - typename B2bGemm::ElementA, typename B2bGemm::LayoutA, - typename B2bGemm::ElementB, typename B2bGemm::LayoutB, - ElementAccumulator, typename B2bGemm::LayoutC, - ElementAccumulator, ElementAccumulator> - reference_gemm_0; - - cutlass::reference::device::Gemm< - typename B2bGemm::ElementA, typename B2bGemm::LayoutA, - typename B2bGemm::ElementB, typename B2bGemm::LayoutB, - typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute, - ElementAccumulator, typename B2bGemm::Operator> - reference_gemm_1; + cutlass::reference::device::GemmComplex< + typename B2bGemm::ElementA, typename B2bGemm::LayoutA, + typename B2bGemm::ElementB, typename B2bGemm::LayoutB, + ElementAccumulator, typename B2bGemm::LayoutC, + ElementAccumulator, ElementAccumulator + >( - reference_gemm_0( problem_size_0, ElementAccumulator(1), //intermediate alpha=1 - tensor_A0.device_ref(), - tensor_B0.device_ref(), + tensor_A0.device_ref(), + cutlass::ComplexTransform::kNone, + tensor_B0.device_ref(), + cutlass::ComplexTransform::kNone, 
ElementAccumulator(0), //beta = 0 reference_Z0.device_ref(), reference_Z0.device_ref(), - ElementAccumulator(0) + ElementAccumulator(0), + int(batch_count), + batch_stride_A0, + batch_stride_B0, + batch_stride_C0, + batch_stride_C0 ); - cutlass::reference::device::TensorScaleBiasGemm< + cutlass::reference::device::TensorScaleBiasGemmBatched< ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute, typename B2bGemm::LayoutScaleBias > ( @@ -652,25 +681,45 @@ struct B2bFusedGemmRun reference_D0.device_ref(), alpha0, tensor_Scale0.device_ref(), - tensor_Bias0.device_ref() + tensor_Bias0.device_ref(), + int(batch_count), + batch_stride_C0, + batch_stride_C0, + batch_stride_Scale0, + batch_stride_Bias0 ); if(relu) { - cutlass::reference::device::TensorReLu(reference_D0.device_view()); + cutlass::reference::device::TensorReLu(reference_D0.device_view()); } - reference_gemm_1( + cutlass::reference::device::GemmComplex< + typename B2bGemm::ElementA, typename B2bGemm::LayoutA, + typename B2bGemm::ElementB, typename B2bGemm::LayoutB, + typename B2bGemm::ElementC, typename B2bGemm::LayoutC, + ElementCompute, ElementAccumulator + >( problem_size_1, - alpha1, - reference_D0.device_ref(), - tensor_B1.device_ref(), - beta1, + alpha1, //intermediate alpha=1 + reference_D0.device_ref(), + cutlass::ComplexTransform::kNone, + tensor_B1.device_ref(), + cutlass::ComplexTransform::kNone, + beta1, //beta = 0 {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, - reference_D1.device_ref() + reference_D1.device_ref(), + ElementAccumulator(0), + int(batch_count), + batch_stride_C0, + batch_stride_B1, + batch_stride_C1, + batch_stride_D1 ); + if(relu) { - cutlass::reference::device::TensorReLu(reference_D1.device_view()); + cutlass::reference::device::TensorReLu(reference_D1.device_view()); } + cudaDeviceSynchronize(); reference_D0.sync_host(); reference_D1.sync_host(); @@ -680,7 +729,7 @@ struct B2bFusedGemmRun CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); bool passed = cutlass::reference::host::TensorEquals( - reference_D1.host_view(), + reference_D1.host_view(), tensor_D1.host_view()); CHECK_TRUE(passed); @@ -694,7 +743,7 @@ struct B2bFusedGemmRun std::ofstream file(fname.str()); - file + file << "A0 =\n" << tensor_A0.host_view() << "\nB0 =\n" << tensor_B0.host_view() << "\nC0 =\n" << tensor_C0.host_view() diff --git a/examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h b/examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h new file mode 100644 index 0000000000..2206bac0e6 --- /dev/null +++ b/examples/13_two_tensor_op_fusion/b2b_grouped_gemm_run.h @@ -0,0 +1,450 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Containers for running grouped back-to-back GEMMs +*/ + +#pragma once + +#include +#include +#include + +#include "cutlass/util/device_memory.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_relu.h" + +#include "reference/device/tensor_scale_bias.h" +#include "helper.h" + +#define CHECK_GT(val1, val2) \ + if((val1) <= (val2)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; +#define CHECK_TRUE(val) \ + if(!(val)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; + +//////////////////////////////////////////////////////////////////////////////// + +template +struct B2bFusedGroupedGemmRun +{ + + using B2bGemm = B2bGemm_; + using ElementAccumulator = typename B2bGemm::ElementAccumulator; + using ElementCompute = typename B2bGemm::BaseKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Scale; + cutlass::Distribution::Kind init_Bias; + uint64_t seed; + + // + // Methods + // + + B2bFusedGroupedGemmRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), + init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 1, -1, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + 
cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + std::cerr << "Not implemented\n"; + return false; + } + + return true; + } + + /// Executes one test + bool run( + std::vector problem_sizes_0, + std::vector problem_sizes_1, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(0), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(0), + bool relu = true, + int warm_ups = 1, + int runs = 100) { + + using HostTensorA = cutlass::HostTensor; + using HostTensorB = cutlass::HostTensor; + using HostTensorC = cutlass::HostTensor; + using HostTensorScale = cutlass::HostTensor; + using HostTensorZ = cutlass::HostTensor; + using HostTensorBias = cutlass::HostTensor; + + int problem_count = (int)problem_sizes_0.size(); + + std::vector host_tensor_A0(problem_count); + std::vector host_tensor_B0(problem_count); + std::vector host_tensor_C0(problem_count); + std::vector host_tensor_Scale0(problem_count); + std::vector host_tensor_Bias0(problem_count); + std::vector host_tensor_B1(problem_count); + std::vector host_tensor_C1(problem_count); + std::vector host_tensor_Bias1(problem_count); + std::vector host_tensor_D1(problem_count); + std::vector host_tensor_Z(problem_count); + std::vector host_tensor_ref_D0(problem_count); + std::vector host_tensor_ref_D1(problem_count); + + std::vector ref_A0(problem_count); + std::vector ref_B0(problem_count); + std::vector ref_C0(problem_count); + std::vector ref_Scale0(problem_count); + std::vector ref_Bias0(problem_count); + std::vector ref_B1(problem_count); + std::vector ref_C1(problem_count); + std::vector ref_Bias1(problem_count); + std::vector ref_D1(problem_count); + std::vector ref_Z(problem_count); + std::vector ref_ref_D0(problem_count); + std::vector ref_ref_D1(problem_count); + + for (int i = 0; i < problem_count; ++i) { + // + // Allocate the GEMM workspace + // + + auto problem_size_0 = problem_sizes_0[i]; + auto problem_size_1 = problem_sizes_1[i]; + + host_tensor_A0.at(i) = HostTensorA(problem_size_0.mk()); + host_tensor_B0.at(i) = HostTensorB(problem_size_0.kn()); + host_tensor_C0.at(i) = HostTensorC(problem_size_0.mn()); + if (alpha0 == ElementCompute(0)) //per-channel scale + host_tensor_Scale0.at(i) = HostTensorScale(typename HostTensorZ::Layout::TensorCoord{1, problem_size_0.n()}); + host_tensor_Bias0.at(i) = HostTensorScale(typename HostTensorBias::Layout::TensorCoord{1, problem_size_0.n()}); + host_tensor_Z.at(i) = HostTensorZ(problem_size_0.mn()); + host_tensor_ref_D0.at(i) = HostTensorC(problem_size_0.mn()); + host_tensor_B1.at(i) = HostTensorB(problem_size_1.kn()); + host_tensor_C1.at(i) = HostTensorC(problem_size_1.mn()); + host_tensor_Bias1.at(i) = HostTensorScale(typename HostTensorBias::Layout::TensorCoord{1, problem_size_1.n()}); + host_tensor_D1.at(i) = HostTensorC(problem_size_1.mn()); + host_tensor_ref_D1.at(i) = HostTensorC(problem_size_1.mn()); + + CHECK_TRUE(initialize_tensor(host_tensor_A0.at(i).host_view(), init_A, seed + 2019)); 
+ CHECK_TRUE(initialize_tensor(host_tensor_B0.at(i).host_view(), init_B, seed + 2018)); + CHECK_TRUE(initialize_tensor(host_tensor_C0.at(i).host_view(), init_C, seed + 2017)); + if (alpha0 == ElementCompute(0)) //per-channel scale + CHECK_TRUE(initialize_tensor(host_tensor_Scale0.at(i).host_view(), init_Scale, seed + 2014)); + CHECK_TRUE(initialize_tensor(host_tensor_Bias0.at(i).host_view(), init_Bias, seed + 2013)); + CHECK_TRUE(initialize_tensor(host_tensor_B1.at(i).host_view(), init_B, seed + 2016)); + CHECK_TRUE(initialize_tensor(host_tensor_C1.at(i).host_view(), init_C, seed + 2015)); + CHECK_TRUE(initialize_tensor(host_tensor_Bias1.at(i).host_view(), init_Bias, seed + 2012)); + + cutlass::reference::host::TensorFill( + host_tensor_D1.at(i).host_view()); + cutlass::reference::host::TensorFill( + host_tensor_ref_D0.at(i).host_view()); + cutlass::reference::host::TensorFill( + host_tensor_ref_D1.at(i).host_view()); + + host_tensor_A0.at(i).sync_device(); + host_tensor_B0.at(i).sync_device(); + host_tensor_C0.at(i).sync_device(); + if (alpha0 == ElementCompute(0)) //per-channel scale + host_tensor_Scale0.at(i).sync_device(); + host_tensor_Bias0.at(i).sync_device(); + host_tensor_B1.at(i).sync_device(); + host_tensor_C1.at(i).sync_device(); + host_tensor_Bias1.at(i).sync_device(); + host_tensor_D1.at(i).sync_device(); + host_tensor_ref_D0.at(i).sync_device(); + host_tensor_ref_D1.at(i).sync_device(); + + ref_A0.at(i) = (host_tensor_A0.at(i).device_ref()); + ref_B0.at(i) = (host_tensor_B0.at(i).device_ref()); + ref_C0.at(i) = (host_tensor_C0.at(i).device_ref()); + if (alpha0 == ElementCompute(0)) //per-channel scale + ref_Scale0.at(i) = (host_tensor_Scale0.at(i).device_ref()); + ref_Bias0.at(i) = (host_tensor_Bias0.at(i).device_ref()); + ref_B1.at(i) = (host_tensor_B1.at(i).device_ref()); + ref_C1.at(i) = {host_tensor_Bias1.at(i).device_data(), typename B2bGemm::LayoutC::Stride(0)}; + ref_Bias1.at(i) = (host_tensor_Bias1.at(i).device_ref()); + ref_D1.at(i) = (host_tensor_D1.at(i).device_ref()); + ref_Z.at(i) = (host_tensor_Z.at(i).device_ref()); + ref_ref_D0.at(i) = (host_tensor_ref_D0.at(i).device_ref()); + ref_ref_D1.at(i) = (host_tensor_ref_D1.at(i).device_ref()); + } + + // + // Initialize the GEMM operator + // + + cutlass::DeviceAllocation device_ref_A0(problem_count); + device_ref_A0.copy_from_host(ref_A0.data()); + cutlass::DeviceAllocation device_ref_B0(problem_count); + device_ref_B0.copy_from_host(ref_B0.data()); + cutlass::DeviceAllocation device_ref_C0(problem_count); + device_ref_C0.copy_from_host(ref_C0.data()); + cutlass::DeviceAllocation device_ref_Scale0(problem_count); + device_ref_Scale0.copy_from_host(ref_Scale0.data()); + cutlass::DeviceAllocation device_ref_Bias0(problem_count); + device_ref_Bias0.copy_from_host(ref_Bias0.data()); + cutlass::DeviceAllocation device_ref_B1(problem_count); + device_ref_B1.copy_from_host(ref_B1.data()); + cutlass::DeviceAllocation device_ref_C1(problem_count); + device_ref_C1.copy_from_host(ref_C1.data()); + cutlass::DeviceAllocation device_ref_Bias1(problem_count); + device_ref_Bias1.copy_from_host(ref_Bias1.data()); + cutlass::DeviceAllocation device_ref_D1(problem_count); + device_ref_D1.copy_from_host(ref_D1.data()); + + cutlass::DeviceAllocation device_problem_sizes_0(problem_count); + device_problem_sizes_0.copy_from_host(problem_sizes_0.data()); + cutlass::DeviceAllocation device_problem_sizes_1(problem_count); + device_problem_sizes_1.copy_from_host(problem_sizes_1.data()); + + B2bGemm b2b_gemm_op; + + int threadblock_count = 
B2bGemm::sufficient(problem_sizes_1.data(), problem_count); + if (!threadblock_count) { + std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Grouped GEMM kernel." << std::endl; + return false; + } + + typename B2bGemm::Arguments arguments{ + problem_count, + device_problem_sizes_0.get(), + device_problem_sizes_1.get(), + device_ref_A0.get(), + device_ref_B0.get(), + device_ref_C0.get(), + device_ref_Scale0.get(), + device_ref_Bias0.get(), + device_ref_B1.get(), + device_ref_C1.get(), + device_ref_D1.get(), + {alpha0, beta0}, + {alpha1, beta1}, + threadblock_count + }; + + cutlass::Status status = b2b_gemm_op.can_implement(arguments); + + if(status != cutlass::Status::kSuccess) { + std::cout << "Problem sizes not supported.\n" + << "Requirments:\n" + << " problem_size_0.M = problem_size_1.M\n" + << " problem_size_0.N = problem_size_1.K\n" + << " ThreadblockShape0::kN = problem_size_0.N\n" + << " ThreadblockShape1::kN = problem_size_1.N" << std::endl; + } + + status = b2b_gemm_op.initialize(arguments); + + CUTLASS_CHECK(status); + + for(int i = 0; i < warm_ups; i++) { + status = b2b_gemm_op(); + CUTLASS_CHECK(status); + } + + // + // Run the GEMM + // + + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + cudaEventRecord(start); + + for(int i = 0; i < runs; i++) { + status = b2b_gemm_op(); + CUTLASS_CHECK(status); + } + + cudaEventRecord(stop); + cudaDeviceSynchronize(); + float gemmTime; + cudaEventElapsedTime(&gemmTime, start, stop); + std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n"; + + for (int i = 0; i < problem_count; ++i) { + host_tensor_D1.at(i).sync_host(); + + // + // Verify + // + + cutlass::reference::device::Gemm< + typename B2bGemm::ElementA, typename B2bGemm::LayoutA, + typename B2bGemm::ElementB, typename B2bGemm::LayoutB, + ElementAccumulator, typename B2bGemm::LayoutC, + ElementAccumulator, ElementAccumulator> + reference_gemm_0; + + cutlass::reference::device::Gemm< + typename B2bGemm::ElementA, typename B2bGemm::LayoutA, + typename B2bGemm::ElementB, typename B2bGemm::LayoutB, + typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute, + ElementAccumulator> + reference_gemm_1; + + auto problem_size_0 = problem_sizes_0[i]; + auto problem_size_1 = problem_sizes_1[i]; + + reference_gemm_0( + problem_size_0, + ElementAccumulator(1), //intermediate alpha=1 + ref_A0.at(i), + ref_B0.at(i), + ElementAccumulator(0), //beta = 0 + ref_Z.at(i), + ref_Z.at(i), + ElementAccumulator(0) + ); + + cutlass::reference::device::TensorScaleBiasGemm< + ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC, + ElementCompute, typename B2bGemm::LayoutC + > ( + problem_size_0, + ref_Z.at(i), + ref_ref_D0.at(i), + alpha0, + ref_Scale0.at(i), + ref_Bias0.at(i) + ); + + if(relu) { + cutlass::reference::device::TensorReLu(host_tensor_ref_D0.at(i).device_view()); + } + + reference_gemm_1( + problem_size_1, + alpha1, + ref_ref_D0.at(i), + ref_B1.at(i), + beta1, + {host_tensor_Bias1.at(i).device_data(), typename B2bGemm::LayoutC::Stride(0)}, + ref_ref_D1.at(i) + ); + if(relu) { + cutlass::reference::device::TensorReLu(host_tensor_ref_D1.at(i).device_view()); + } + cudaDeviceSynchronize(); + host_tensor_ref_D0.at(i).sync_host(); + host_tensor_ref_D1.at(i).sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_ref_D0.at(i).host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_D1.at(i).host_view()), 0); + 
CHECK_GT(cutlass::reference::host::TensorNorm(host_tensor_ref_D1.at(i).host_view()), 0); + + bool passed = cutlass::reference::host::TensorEquals( + host_tensor_ref_D1.at(i).host_view(), + host_tensor_D1.at(i).host_view()); + + CHECK_TRUE(passed); + if (!passed) + { + + std::stringstream fname; + + fname << "error_B2bGemm_device_fused.txt"; + std::cerr << "Check failed for GEMM " << i << " in the group." << std::endl; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream file(fname.str()); + + file + << "GEMM " << i << " in group\n" + << "A0 =\n" << host_tensor_A0.at(i).host_view() + << "\nB0 =\n" << host_tensor_B0.at(i).host_view() + << "\nC0 =\n" << host_tensor_C0.at(i).host_view() + << "\nScale0:\n" << host_tensor_Scale0.at(i).host_view() << "\n" + << "\nBias0:\n" << host_tensor_Bias0.at(i).host_view() << "\n" + << "\nB1 =\n" << host_tensor_B1.at(i).host_view() + << "\nC1 =\n" << host_tensor_C1.at(i).host_view() + << "\nBias1:\n" << host_tensor_Bias1.at(i).host_view() << "\n" + << "\n\nReference =\n" << host_tensor_ref_D1.at(i).host_view() + << "\nComputed =\n" << host_tensor_D1.at(i).host_view(); + + return false; + } + } + return true; + } + +}; + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/13_two_tensor_op_fusion/b2b_interleaved_conv2d_run.h b/examples/13_two_tensor_op_fusion/b2b_interleaved_conv2d_run.h index f9905fa521..f70c21af8b 100644 --- a/examples/13_two_tensor_op_fusion/b2b_interleaved_conv2d_run.h +++ b/examples/13_two_tensor_op_fusion/b2b_interleaved_conv2d_run.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h b/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h index 95c404d9ec..43a33b12f7 100644 --- a/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h +++ b/examples/13_two_tensor_op_fusion/b2b_interleaved_gemm_run.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -43,6 +43,7 @@ #include "cutlass/util/reference/host/tensor_norm.h" #include "cutlass/util/host_reorder.h" #include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/gemm_complex.h" #include "cutlass/util/reference/device/tensor_relu.h" #include "reference/device/tensor_scale_bias.h" @@ -76,9 +77,9 @@ struct B2bInterleavedNonFusedGemmRun // B2bInterleavedNonFusedGemmRun( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, uint64_t seed_ = 2080 ): @@ -87,7 +88,7 @@ struct B2bInterleavedNonFusedGemmRun /// Helper to initialize a tensor view template bool initialize_tensor( - cutlass::TensorView view, + cutlass::TensorView view, cutlass::Distribution::Kind dist_kind, uint64_t seed) { @@ -95,7 +96,7 @@ struct B2bInterleavedNonFusedGemmRun cutlass::reference::host::TensorFillRandomUniform( view, seed, 2, -2, 0); - } + } else if (dist_kind == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(view); @@ -128,73 +129,72 @@ struct B2bInterleavedNonFusedGemmRun /// Executes one test bool run( - cutlass::gemm::GemmCoord problem_size_0, - cutlass::gemm::GemmCoord problem_size_1, - ElementCompute alpha0 = ElementCompute(1), + cutlass::gemm::GemmCoord problem_size_0, + cutlass::gemm::GemmCoord problem_size_1, + ElementCompute alpha0 = ElementCompute(1), ElementCompute beta0 = ElementCompute(0), - ElementCompute alpha1 = ElementCompute(1), + ElementCompute alpha1 = ElementCompute(1), ElementCompute beta1 = ElementCompute(0), bool relu = true, int warm_ups = 1, int runs = 100) { - + // // Allocate the GEMM workspace // cutlass::HostTensor< - typename Gemm0::ElementA, + typename Gemm0::ElementA, typename Gemm0::LayoutA> tensor_A0(problem_size_0.mk()); cutlass::HostTensor< - typename Gemm0::ElementB, + typename Gemm0::ElementB, typename Gemm0::LayoutB> tensor_B0(problem_size_0.kn()); cutlass::HostTensor< - typename Gemm0::ElementB, + typename Gemm0::ElementB, typename Gemm0::LayoutB> tensor_B0_reordered(problem_size_0.kn()); cutlass::HostTensor< - typename Gemm0::ElementC, + typename Gemm0::ElementC, typename Gemm0::LayoutC> tensor_C0(problem_size_0.mn()); cutlass::HostTensor< - typename Gemm0::ElementC, + typename Gemm0::ElementC, typename Gemm0::LayoutC> tensor_Bias0({1, problem_size_0.n()}); cutlass::HostTensor< - typename Gemm0::ElementC, + typename Gemm0::ElementC, typename Gemm0::LayoutC> tensor_D0(problem_size_0.mn()); cutlass::HostTensor< - typename Gemm0::ElementC, + typename Gemm0::ElementC, typename Gemm0::LayoutC> reference_D0(problem_size_0.mn()); cutlass::HostTensor< - typename Gemm1::ElementB, + typename Gemm1::ElementB, typename Gemm1::LayoutB> tensor_B1(problem_size_1.kn()); cutlass::HostTensor< - typename Gemm1::ElementB, + typename Gemm1::ElementB, typename Gemm1::LayoutB> tensor_B1_reordered(problem_size_1.kn()); cutlass::HostTensor< - typename Gemm1::ElementC, + typename Gemm1::ElementC, typename Gemm1::LayoutC> tensor_C1(problem_size_1.mn()); cutlass::HostTensor< - 
typename Gemm0::ElementC, + typename Gemm0::ElementC, typename Gemm1::LayoutC> tensor_Bias1({1, problem_size_1.n()}); cutlass::HostTensor< - typename Gemm1::ElementC, + typename Gemm1::ElementC, typename Gemm1::LayoutC> tensor_D1(problem_size_1.mn()); cutlass::HostTensor< - typename Gemm1::ElementC, + typename Gemm1::ElementC, typename Gemm1::LayoutC> reference_D1(problem_size_1.mn()); - CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); @@ -285,13 +285,13 @@ struct B2bInterleavedNonFusedGemmRun for(int i = 0; i < runs; i++) { status = gemm_op_0(); - + CUTLASS_CHECK(status); } cudaEventRecord(stop1); for(int i = 0; i < runs; i++) { status = gemm_op_1(); - + CUTLASS_CHECK(status); } @@ -327,36 +327,36 @@ struct B2bInterleavedNonFusedGemmRun reference_gemm_0( problem_size_0, - alpha0, - tensor_A0.device_ref(), - tensor_B0.device_ref(), - beta0, + alpha0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + beta0, {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, reference_D0.device_ref() ); if(relu) { - cutlass::reference::device::TensorReLu(reference_D0.device_view()); + cutlass::reference::device::TensorReLu(reference_D0.device_view()); } reference_gemm_1( problem_size_1, - alpha1, - reference_D0.device_ref(), - tensor_B1.device_ref(), + alpha1, + reference_D0.device_ref(), + tensor_B1.device_ref(), beta1, {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, reference_D1.device_ref() ); - + if(relu) { - cutlass::reference::device::TensorReLu(reference_D1.device_view()); + cutlass::reference::device::TensorReLu(reference_D1.device_view()); } // Wait for kernels to finish cudaDeviceSynchronize(); - reference_D0.sync_host(); - reference_D1.sync_host(); + reference_D0.sync_host(); + reference_D1.sync_host(); CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0); CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); @@ -364,7 +364,7 @@ struct B2bInterleavedNonFusedGemmRun CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); bool passed = cutlass::reference::host::TensorEquals( - reference_D1.host_view(), + reference_D1.host_view(), tensor_D1.host_view()); CHECK_TRUE(passed); @@ -377,7 +377,7 @@ struct B2bInterleavedNonFusedGemmRun std::ofstream file(fname.str()); - file + file << "A0 =\n" << tensor_A0.host_view() << "\nB0 =\n" << tensor_B0.host_view() << "\nB0_reordered =\n" << tensor_B0_reordered.host_view() @@ -416,9 +416,9 @@ struct B2bInterleavedFusedGemmRun // B2bInterleavedFusedGemmRun( - cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, - cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, uint64_t seed_ = 2080 @@ -429,7 +429,7 @@ struct B2bInterleavedFusedGemmRun /// Helper to initialize a tensor view template bool initialize_tensor( - cutlass::TensorView view, + cutlass::TensorView view, cutlass::Distribution::Kind dist_kind, uint64_t seed) { @@ -437,11 
+437,11 @@ struct B2bInterleavedFusedGemmRun cutlass::reference::host::TensorFillRandomUniform( view, seed, 2, -2, 0); - } + } else if (dist_kind == cutlass::Distribution::Identity) { cutlass::reference::host::TensorFillIdentity(view); - } + } else if (dist_kind == cutlass::Distribution::Gaussian) { cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); @@ -470,78 +470,99 @@ struct B2bInterleavedFusedGemmRun /// Executes one test bool run( - cutlass::gemm::GemmCoord problem_size_0, - cutlass::gemm::GemmCoord problem_size_1, - ElementCompute alpha0 = ElementCompute(1), + cutlass::gemm::GemmCoord problem_size_0, + cutlass::gemm::GemmCoord problem_size_1, + ElementCompute alpha0 = ElementCompute(1), ElementCompute beta0 = ElementCompute(0), - ElementCompute alpha1 = ElementCompute(1), + ElementCompute alpha1 = ElementCompute(1), ElementCompute beta1 = ElementCompute(0), + cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm, + + // batch_count is used as split-k when mode is kGemm according + // to the GemmUniversal interface + + int batch_count = 1, + + int64_t batch_stride_A0 = 0, + int64_t batch_stride_B0 = 0, + int64_t batch_stride_C0 = 0, + int64_t batch_stride_B1 = 0, + int64_t batch_stride_C1 = 0, + int64_t batch_stride_D1 = 0, + int64_t batch_stride_Bias0 = 0, + int64_t batch_stride_Scale0 = 0, bool relu = true, int warm_ups = 1, int runs = 100) { - + // // Allocate the GEMM workspace // + cutlass::gemm::GemmCoord CoordA0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k()); + cutlass::gemm::GemmCoord CoordB0(problem_size_0.m(), problem_size_0.n(), batch_count * problem_size_0.k()); + cutlass::gemm::GemmCoord CoordC0(problem_size_0.m(), batch_count * problem_size_0.n(), problem_size_0.k()); + cutlass::gemm::GemmCoord CoordB1(problem_size_1.m(), problem_size_1.n(), batch_count * problem_size_1.k()); + cutlass::gemm::GemmCoord CoordC1(problem_size_1.m(), batch_count * problem_size_1.n(), problem_size_1.k()); + cutlass::HostTensor< - typename B2bGemm::ElementA, - typename B2bGemm::LayoutA> tensor_A0(problem_size_0.mk()); + typename B2bGemm::ElementA, + typename B2bGemm::LayoutA> tensor_A0(CoordA0.mk()); cutlass::HostTensor< - typename B2bGemm::ElementB, - typename B2bGemm::LayoutB> tensor_B0(problem_size_0.kn()); + typename B2bGemm::ElementB, + typename B2bGemm::LayoutB> tensor_B0(CoordB0.kn()); cutlass::HostTensor< - typename B2bGemm::ElementB, - typename B2bGemm::LayoutB> tensor_B0_reordered(problem_size_0.kn()); + typename B2bGemm::ElementB, + typename B2bGemm::LayoutB> tensor_B0_reordered(CoordB0.kn()); cutlass::HostTensor< - typename B2bGemm::ElementC, - typename B2bGemm::LayoutC> tensor_C0(problem_size_0.mn()); + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> tensor_C0(CoordC0.mn()); cutlass::HostTensor< - typename B2bGemm::ElementScaleBias, + typename B2bGemm::ElementScaleBias, typename B2bGemm::LayoutScaleBias> tensor_Scale0; if(alpha0 == ElementCompute(0)) //per-channel scale - tensor_Scale0.resize({1, problem_size_0.n()}); + tensor_Scale0.resize({1, batch_count * problem_size_0.n()}); cutlass::HostTensor< - typename B2bGemm::ElementScaleBias, - typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, problem_size_0.n()}); + typename B2bGemm::ElementScaleBias, + typename B2bGemm::LayoutScaleBias> tensor_Bias0({1, batch_count * problem_size_0.n()}); cutlass::HostTensor< - ElementAccumulator, - typename B2bGemm::LayoutC> reference_Z0(problem_size_0.mn()); + ElementAccumulator, + typename B2bGemm::LayoutC> 
reference_Z0(CoordC0.mn()); cutlass::HostTensor< - typename B2bGemm::ElementC, - typename B2bGemm::LayoutC> reference_D0(problem_size_0.mn()); + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> reference_D0(CoordC0.mn()); cutlass::HostTensor< - typename B2bGemm::ElementB, - typename B2bGemm::LayoutB> tensor_B1(problem_size_1.kn()); + typename B2bGemm::ElementB, + typename B2bGemm::LayoutB> tensor_B1(CoordB1.kn()); cutlass::HostTensor< - typename B2bGemm::ElementB, - typename B2bGemm::LayoutB> tensor_B1_reordered(problem_size_1.kn()); + typename B2bGemm::ElementB, + typename B2bGemm::LayoutB> tensor_B1_reordered(CoordB1.kn()); cutlass::HostTensor< - typename B2bGemm::ElementC, - typename B2bGemm::LayoutC> tensor_C1(problem_size_1.mn()); + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> tensor_C1(CoordC1.mn()); cutlass::HostTensor< - typename B2bGemm::ElementC, - typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, problem_size_1.n()}); + typename B2bGemm::ElementC, + typename B2bGemm::LayoutScaleBias> tensor_Bias1({1, batch_count * problem_size_1.n()}); cutlass::HostTensor< - typename B2bGemm::ElementC, - typename B2bGemm::LayoutC> tensor_D1(problem_size_1.mn()); + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> tensor_D1(CoordC1.mn()); cutlass::HostTensor< - typename B2bGemm::ElementC, - typename B2bGemm::LayoutC> reference_D1(problem_size_1.mn()); + typename B2bGemm::ElementC, + typename B2bGemm::LayoutC> reference_D1(CoordC1.mn()); CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); @@ -556,9 +577,9 @@ struct B2bInterleavedFusedGemmRun //Reorder B0 cutlass::reorder_column<16>( - tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), problem_size_0); + tensor_B0_reordered.host_ref(), tensor_B0.host_ref(), CoordB0); cutlass::reorder_column( - tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), problem_size_1); + tensor_B1_reordered.host_ref(), tensor_B1.host_ref(), CoordB1); cutlass::reference::host::TensorFill( tensor_D1.host_view()); @@ -581,12 +602,14 @@ struct B2bInterleavedFusedGemmRun tensor_D1.sync_device(); reference_D0.sync_device(); reference_D1.sync_device(); + // tensor_Bias0_batched.sync_device(); // // Initialize the GEMM operator // typename B2bGemm::Arguments arguments{ + mode, problem_size_0, problem_size_1, tensor_A0.device_ref(), @@ -597,8 +620,16 @@ struct B2bInterleavedFusedGemmRun tensor_B1_reordered.device_ref(), {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, tensor_D1.device_ref(), + batch_stride_A0, + batch_stride_B0, + batch_stride_B1, + batch_stride_C1, + batch_stride_D1, + batch_stride_Bias0, + batch_stride_Scale0, {alpha0, beta0}, {alpha1, beta1}, + batch_count, }; B2bGemm b2b_gemm_op; @@ -651,32 +682,30 @@ struct B2bInterleavedFusedGemmRun // Verify // - cutlass::reference::device::Gemm< - typename B2bGemm::ElementA, typename B2bGemm::LayoutA, - typename B2bGemm::ElementB, typename B2bGemm::LayoutB, - ElementAccumulator, typename B2bGemm::LayoutC, - ElementAccumulator, ElementAccumulator> - reference_gemm_0; - - cutlass::reference::device::Gemm< - typename B2bGemm::ElementA, typename B2bGemm::LayoutA, - typename B2bGemm::ElementB, typename B2bGemm::LayoutB, - typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute, - ElementAccumulator, typename B2bGemm::Operator> - reference_gemm_1; - - reference_gemm_0( + cutlass::reference::device::GemmComplex< + typename B2bGemm::ElementA, typename B2bGemm::LayoutA, + typename B2bGemm::ElementB, typename B2bGemm::LayoutB, + 
ElementAccumulator, typename B2bGemm::LayoutC, + ElementAccumulator, ElementAccumulator + >( problem_size_0, ElementAccumulator(1), //intermediate alpha=1 - tensor_A0.device_ref(), - tensor_B0.device_ref(), + tensor_A0.device_ref(), + cutlass::ComplexTransform::kNone, + tensor_B0.device_ref(), + cutlass::ComplexTransform::kNone, ElementAccumulator(0), //beta = 0 reference_Z0.device_ref(), reference_Z0.device_ref(), - ElementAccumulator(0) + ElementAccumulator(0), + int(batch_count), + batch_stride_A0, + batch_stride_B0, + batch_stride_C0, + batch_stride_C0 ); - cutlass::reference::device::TensorScaleBiasGemm< + cutlass::reference::device::TensorScaleBiasGemmBatched< ElementAccumulator, typename B2bGemm::ElementC, typename B2bGemm::LayoutC, ElementCompute, typename B2bGemm::LayoutScaleBias > ( @@ -685,25 +714,45 @@ struct B2bInterleavedFusedGemmRun reference_D0.device_ref(), alpha0, tensor_Scale0.device_ref(), - tensor_Bias0.device_ref() + tensor_Bias0.device_ref(), + int(batch_count), + batch_stride_C0, + batch_stride_C0, + batch_stride_Scale0, + batch_stride_Bias0 ); if(relu) { - cutlass::reference::device::TensorReLu(reference_D0.device_view()); + cutlass::reference::device::TensorReLu(reference_D0.device_view()); } - reference_gemm_1( + cutlass::reference::device::GemmComplex< + typename B2bGemm::ElementA, typename B2bGemm::LayoutA, + typename B2bGemm::ElementB, typename B2bGemm::LayoutB, + typename B2bGemm::ElementC, typename B2bGemm::LayoutC, + ElementCompute, ElementAccumulator + >( problem_size_1, - alpha1, - reference_D0.device_ref(), - tensor_B1.device_ref(), - beta1, + alpha1, //intermediate alpha=1 + reference_D0.device_ref(), + cutlass::ComplexTransform::kNone, + tensor_B1.device_ref(), + cutlass::ComplexTransform::kNone, + beta1, //beta = 0 {tensor_Bias1.device_data(), typename B2bGemm::LayoutC::Stride(0)}, - reference_D1.device_ref() + reference_D1.device_ref(), + ElementAccumulator(0), + int(batch_count), + batch_stride_C0, + batch_stride_B1, + batch_stride_C1, + batch_stride_D1 ); + if(relu) { - cutlass::reference::device::TensorReLu(reference_D1.device_view()); + cutlass::reference::device::TensorReLu(reference_D1.device_view()); } + cudaDeviceSynchronize(); reference_D0.sync_host(); reference_D1.sync_host(); @@ -713,7 +762,7 @@ struct B2bInterleavedFusedGemmRun CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); bool passed = cutlass::reference::host::TensorEquals( - reference_D1.host_view(), + reference_D1.host_view(), tensor_D1.host_view()); CHECK_TRUE(passed); @@ -727,7 +776,7 @@ struct B2bInterleavedFusedGemmRun std::ofstream file(fname.str()); - file + file << "A0 =\n" << tensor_A0.host_view() << "\nB0 =\n" << tensor_B0.host_view() << "\nB0_reordered =\n" << tensor_B0_reordered.host_view() diff --git a/examples/13_two_tensor_op_fusion/device/b2b_gemm.h b/examples/13_two_tensor_op_fusion/device/b2b_gemm.h index 3751cc82b8..338090764e 100644 --- a/examples/13_two_tensor_op_fusion/device/b2b_gemm.h +++ b/examples/13_two_tensor_op_fusion/device/b2b_gemm.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -119,8 +119,6 @@ template < int AlignmentB = DefaultGemmConfiguration::kAlignmentB, - /// If true, kernel supports split-K with serial reduction - bool SplitKSerial = false, /// Operation performed by GEMM typename Operator_ = typename DefaultGemmConfiguration< OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, @@ -154,7 +152,6 @@ class B2bGemm { static int const kAlignmentA = AlignmentA; static int const kAlignmentB = AlignmentB; static int const kAlignmentC = EpilogueOutputOp1::kCount; - static bool const kSplitKSerial = SplitKSerial; static ComplexTransform const kTransformA = ComplexTransform::kNone; static ComplexTransform const kTransformB = ComplexTransform::kNone; @@ -184,77 +181,11 @@ class B2bGemm { EpilogueOutputOp1, ThreadblockSwizzle, kStages, - kSplitKSerial, Operator, SmemAccumulator >::B2bGemmKernel; - /// Argument structure - struct Arguments { - - // - // Data members - // - - GemmCoord problem_size_0; - GemmCoord problem_size_1; - TensorRef ref_A0; - TensorRef ref_B0; - TensorRef ref_C0; - TensorRef ref_Scale0; - TensorRef ref_Bias0; - TensorRef ref_B1; - TensorRef ref_C1; - TensorRef ref_D1; - typename EpilogueOutputOp0::Params epilogue0; - typename EpilogueOutputOp1::Params epilogue1; - int split_k_slices; - - // - // Methods - // - - /// Default ctor - CUTLASS_HOST_DEVICE - Arguments(): problem_size_0(0, 0, 0), problem_size_1(0, 0, 0), split_k_slices(1) { - - } - - /// Constructs an Arguments structure - CUTLASS_HOST_DEVICE - Arguments( - GemmCoord problem_size_0_, - GemmCoord problem_size_1_, - TensorRef ref_A0_, - TensorRef ref_B0_, - TensorRef ref_C0_, - TensorRef ref_Scale0_, - TensorRef ref_Bias0_, - TensorRef ref_B1_, - TensorRef ref_C1_, - TensorRef ref_D1_, - typename EpilogueOutputOp0::Params epilogue0_ = - typename EpilogueOutputOp0::Params(), - typename EpilogueOutputOp1::Params epilogue1_ = - typename EpilogueOutputOp1::Params(), - int split_k_slices_ = 1 - ): - problem_size_0(problem_size_0_), - problem_size_1(problem_size_1_), - ref_A0(ref_A0_), - ref_B0(ref_B0_), - ref_C0(ref_C0_), - ref_Scale0(ref_Scale0_), - ref_Bias0(ref_Bias0_), - ref_B1(ref_B1_), - ref_C1(ref_C1_), - ref_D1(ref_D1_), - epilogue0(epilogue0_), - epilogue1(epilogue1_), - split_k_slices(split_k_slices_) { - - } - }; + using Arguments = typename B2bGemmKernel::Arguments; private: @@ -269,10 +200,6 @@ class B2bGemm { /// Determines whether the GEMM can execute the given problem. 
static Status can_implement(Arguments const &args) { - if (!kSplitKSerial && args.split_k_slices > 1) { - return Status::kErrorInvalidProblem; - } - Status status = B2bGemmKernel::can_implement( args.problem_size_0, args.problem_size_1, @@ -295,20 +222,14 @@ class B2bGemm { static size_t get_workspace_size(Arguments const &args) { size_t bytes = 0; - + // Determine grid shape ThreadblockSwizzle threadblock_swizzle; cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( - args.problem_size_0, + args.problem_size_0, {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK}, - args.split_k_slices); - - if (kSplitKSerial && args.split_k_slices > 1) { - - - bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); - } + args.batch_count); return bytes; } @@ -320,38 +241,17 @@ class B2bGemm { ThreadblockSwizzle threadblock_swizzle; cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( - args.problem_size_0, + args.problem_size_0, {ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK}, - args.split_k_slices); + args.batch_count); // cutlass::gemm::GemmCoord grid_shape_1 = threadblock_swizzle.get_tiled_shape( -// args.problem_size_1, +// args.problem_size_1, // {ThreadblockShape1::kM, ThreadblockShape1::kN, ThreadblockShape1::kK}, -// args.split_k_slices); - - if (kSplitKSerial) { - if (args.split_k_slices > 1) { - if (!workspace) { - return Status::kErrorWorkspaceNull; - } - - size_t bytes = get_workspace_size(args); - - cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); - - if (result != cudaSuccess) { - return Status::kErrorInternal; - } - } - } - else { - - if (args.split_k_slices > 1) { - return Status::kErrorInvalidProblem; - } - } +// args.batch_count); // Initialize the Params structure params_ = typename B2bGemmKernel::Params{ + args.mode, args.problem_size_0, args.problem_size_1, grid_shape, @@ -363,6 +263,13 @@ class B2bGemm { args.ref_B1.non_const_ref(), args.ref_C1.non_const_ref(), args.ref_D1, + args.batch_stride_A0, + args.batch_stride_B0, + args.batch_stride_B1, + args.batch_stride_C1, + args.batch_stride_D1, + args.batch_stride_Bias0, + args.batch_stride_Scale0, args.epilogue0, args.epilogue1, static_cast(workspace), @@ -373,12 +280,6 @@ class B2bGemm { /// Lightweight update given a subset of arguments Status update(Arguments const &args, void *workspace = nullptr) { - - if (kSplitKSerial && args.split_k_slices > 1) { - if (!workspace) { - return Status::kErrorWorkspaceNull; - } - } params_.ref_A0.reset(args.ref_A0.non_const_ref().data()); params_.ref_B0.reset(args.ref_B0.non_const_ref().data()); @@ -430,12 +331,12 @@ class B2bGemm { /// Runs the kernel using initialized state. Status operator()( - Arguments const &args, - void *workspace = nullptr, + Arguments const &args, + void *workspace = nullptr, cudaStream_t stream = nullptr) { - + Status status = initialize(args, workspace, stream); - + if (status == Status::kSuccess) { status = run(stream); } diff --git a/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h b/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h index 7dd8fe2889..5d6a0e94f5 100644 --- a/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h +++ b/examples/13_two_tensor_op_fusion/device/b2b_implicit_gemm_convolution.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_rf.cu b/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_rf.cu index d9c59db0e1..9f5b89e550 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_rf.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_rf.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_shmem.cu b/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_shmem.cu index 54a1315908..cf7133ee1f 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_shmem.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm75_shmem.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_rf.cu b/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_rf.cu index 7a66f8e88b..be6d7d54e9 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_rf.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_rf.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_shmem.cu b/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_shmem.cu index 5a60714160..50c886d1fe 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_shmem.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_convs_f16_sm80_shmem.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_rf.cu b/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_rf.cu index 2481fbd82e..5e94c7485a 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_rf.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_rf.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -220,7 +220,6 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75_rf_res() { return pass; } - int main() { std::vectorfuncs = { @@ -229,10 +228,6 @@ int main() { }; return testRun(75, funcs, "conv int8 RF residency"); - } - - //////////////////////////////////////////////////////////////////////////////// - diff --git a/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_shmem.cu b/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_shmem.cu index 917ae93063..aeea07f2b1 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_shmem.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm75_shmem.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -39,7 +39,6 @@ #include "device/b2b_implicit_gemm_convolution.h" #include "b2b_interleaved_conv2d_run.h" #include "test_run.h" - //////////////////////////////////////////////////////////////////////////////// cutlass::conv::Conv2dProblemSize conv2d_s8_sm75_problem_size_0 ( @@ -219,20 +218,13 @@ bool run_fused_conv2d_fprop_optimized_s8_sm75_shmem() { return pass; } - - int main() { - std::vectorfuncs = { &run_nonfused_conv2d_fprop_optimized_s8_sm75, &run_fused_conv2d_fprop_optimized_s8_sm75_shmem }; return testRun(75, funcs, "conv int8 shmem staging"); - } - - //////////////////////////////////////////////////////////////////////////////// - diff --git a/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_rf.cu b/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_rf.cu index a515f1255b..d91df2a64a 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_rf.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_rf.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_shmem.cu b/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_shmem.cu index 9a5b2c1c56..2b865e6b0a 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_shmem.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_convs_s8_sm80_shmem.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm75_rf.cu b/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm75_rf.cu index 54c8835543..44243c5558 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm75_rf.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm75_rf.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm75_shmem.cu b/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm75_shmem.cu index 30ba26990d..e4709be43d 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm75_shmem.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm75_shmem.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_rf.cu b/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_rf.cu index 0c2239ac2f..0e64d401c0 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_rf.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_rf.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_shmem.cu b/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_shmem.cu index 045e4a8e58..9f6a2a08ec 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_shmem.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_gemms_f16_sm80_shmem.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/13_two_tensor_op_fusion/fused_two_gemms_grouped_f16_sm80_rf.cu b/examples/13_two_tensor_op_fusion/fused_two_gemms_grouped_f16_sm80_rf.cu new file mode 100644 index 0000000000..87331d04d3 --- /dev/null +++ b/examples/13_two_tensor_op_fusion/fused_two_gemms_grouped_f16_sm80_rf.cu @@ -0,0 +1,297 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Example of running grouped back-to-back GEMMs when intermediate results are RF resident +*/ + +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/base_grouped.h" +#include "cutlass/gemm/device/gemm.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "device/b2b_gemm.h" +#include "kernel/default_b2b_gemm.h" +#include "threadblock/grouped_threadblock_swizzle.h" +#include "b2b_grouped_gemm_run.h" +#include "test_run.h" + +//////////////////////////////////////////////////////////////////////////////// + +std::vector gemm_f16_sm80_problem_sizes_0; +std::vector gemm_f16_sm80_problem_sizes_1; + +// Constraints: +// 1. Warp shape N must equal thread block shape N +// 2. 
Problem size N must equal thread block shape N +using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 32>; +using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 32>; +using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 32>; +using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 32>; + +// Command line options parsing +struct Options { + + bool help; + bool error; + bool reference_check; + int alignment = 8; + + std::vector problem_sizes0; + std::vector problem_sizes1; + + int problem_count; + bool verbose; + + // + // Methods + // + + Options(): + help(false), + error(false), + reference_check(true), + problem_count(15), + verbose(false) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("problems", problem_count, 15); + cmd.get_cmd_line_argument("reference-check", reference_check, true); + cmd.get_cmd_line_argument("verbose", verbose, false); + + randomize_problems(cmd); + } + + void randomize_problems(cutlass::CommandLine &cmd) { + + // + // For now, randomly choose the problem sizes. + // + + int cmd_line_m = -1; + int cmd_line_k = -1; + + cmd.get_cmd_line_argument("m", cmd_line_m); + cmd.get_cmd_line_argument("k", cmd_line_k); + + problem_sizes0.reserve(problem_count); + problem_sizes1.reserve(problem_count); + + for (int i = 0; i < problem_count; ++i) { + + int m = cmd_line_m; + int k = cmd_line_k; + + if (m < 1) { + m = alignment * ((rand() % 256) + 1); + } + + if (k < 1) { + k = alignment * ((rand() % 256) + 1); + } + + cutlass::gemm::GemmCoord problem0(m, ThreadblockShape0::kN, k); + cutlass::gemm::GemmCoord problem1(m, ThreadblockShape1::kN, ThreadblockShape0::kN); + + problem_sizes0.push_back(problem0); + problem_sizes1.push_back(problem1); + } + + if (verbose) { + print_problem_sizes(); + } + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "13_fused_two_gemms_grouped_f16_sm80_rf\n\n" + << " This example runs a grouped back-to-back GEMM kernel. A group of independent back-to-back GEMMs are\n" + << " run in a single kernel. Each indivdual problem in the group is subject to the same constraints that non-grouped\n" + << " back-to-back GEMMs are subject to.s" + << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --problems= Number of individual GEMM problems (default: --problems=15)\n" + << " --m= Sets the M dimension of both GEMMs for all groups. Otherwise, it is selected randomly\n" + << " --k= Sets the K dimension of the first GEMM for all groups. 
Otherwise, it is selected randomly\n" + << " --verbose= If true, prints problem sizes.\n"; + + out << "\n\nExamples:\n\n" + + << "# Runs a grouped B2b GEMM with 10 random problem sizes\n" + << "$ ./examples/13_two_tensor_op_fusion/13_fused_two_gemms_grouped_f16_sm80_rf --groups=10\n\n"; + + return out; + } + + void print_problem_sizes() { + std::cout << std::endl; + std::cout << "Executing " << problem_count << " independent back-to-back GEMMs in a group" << std::endl; + for (int i = 0; i < problem_count; ++i) { + cutlass::gemm::GemmCoord problem0 = problem_sizes0.at(i); + cutlass::gemm::GemmCoord problem1 = problem_sizes1.at(i); + std::cout << "Problem " << i + << "\t\tGEMM0: " << problem0.m() << 'x' << problem0.n() << 'x' << problem0.k() + << "\t\tGEMM1: " << problem1.m() << 'x' << problem1.n() << 'x' << problem1.k() + << std::endl; + } + } +}; + +bool run_fused_grouped_gemm_f16_sm80_rf_res() { + + using ElementOutput = cutlass::half_t; + using ElementAccumulator = cutlass::half_t; + using ElementCompute = cutlass::half_t; + + ElementCompute alpha0 = ElementCompute(1); + //Fused kernel has built-in bias, setting beta=0 + ElementCompute beta0 = ElementCompute(0); + ElementCompute alpha1 = ElementCompute(1); + ElementCompute beta1 = ElementCompute(1); //beta=1 for bias + + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + using EpilogueOutputOp0 = + cutlass::epilogue::thread::LinearCombinationRelu< + ElementOutput, + InstructionShape::kM * InstructionShape::kN / 32, + ElementAccumulator, + ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + >; + + using EpilogueOutputOp1 = + cutlass::epilogue::thread::LinearCombinationRelu< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute, + cutlass::epilogue::thread::ScaleType::NoBetaScaling + >; + + using GroupedThreadblockSwizzle = cutlass::gemm::threadblock::B2bGemmGroupedThreadblockSwizzle< + ThreadblockShape0, + cutlass::layout::RowMajor // LayoutC + >; + + const int kAlignment = 128 / cutlass::sizeof_bits::value; + const int kStages = 3; + using B2bGemmKernel = cutlass::gemm::kernel::DefaultB2bGemm< + cutlass::half_t, + cutlass::layout::RowMajor, + kAlignment, + cutlass::half_t, + cutlass::layout::ColumnMajor, + kAlignment, + cutlass::half_t, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + GroupedThreadblockSwizzle, + kStages, + cutlass::arch::OpMultiplyAdd + >::B2bGemmKernel; + + using B2bGemm = cutlass::gemm::device::BaseGrouped; + + B2bFusedGroupedGemmRun fusedGemm; + + std::cout << "Running Fused back-to-back FP16 TN Grouped GEMMs with RF residency...\n"; + bool passed = fusedGemm.run(gemm_f16_sm80_problem_sizes_0, gemm_f16_sm80_problem_sizes_1, alpha0, beta0, alpha1, beta1); + if(passed) + std::cout << "Pass\n"; + else + std::cout << "Fail\n"; + + return passed; +} + +int main(int argc, char const **args) { + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." 
<< std::endl; + return -1; + } + + gemm_f16_sm80_problem_sizes_0 = options.problem_sizes0; + gemm_f16_sm80_problem_sizes_1 = options.problem_sizes1; + + std::vectorfuncs = { + &run_fused_grouped_gemm_f16_sm80_rf_res + }; + + return testRun(80, funcs, "grouped gemm f16 RF residency"); +} + + + + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_rf.cu b/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_rf.cu index 2c00eb86c5..a7f39d2d3d 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_rf.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_rf.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -195,7 +195,6 @@ bool run_fused_gemm_s8_rf_res() { return passed; } - int main() { std::vectorfuncs = { @@ -204,9 +203,6 @@ int main() { }; return testRun(75, funcs, "gemm int8 RF residency"); - - } - //////////////////////////////////////////////////////////////////////////////// diff --git a/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_shmem.cu b/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_shmem.cu index 10f4cb7b51..671f48b712 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_shmem.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm75_shmem.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -43,7 +43,6 @@ #include "device/b2b_gemm.h" #include "b2b_interleaved_gemm_run.h" #include "test_run.h" - //////////////////////////////////////////////////////////////////////////////// cutlass::gemm::GemmCoord gemm_s8_sm75_problem_size_0(128*640, 64, 576); @@ -197,18 +196,13 @@ bool run_fused_gemm_s8_shmem() { return passed; } - int main() { std::vectorfuncs = { &run_nonfused_gemm_s8, &run_fused_gemm_s8_shmem }; - return testRun(75, funcs, "gemm int8 shmem staing"); - - } - //////////////////////////////////////////////////////////////////////////////// diff --git a/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_rf.cu b/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_rf.cu index 38845371b8..b2f12b45f3 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_rf.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_rf.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
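// Usage note for the grouped example added above (a sketch; flag names follow its
// Options::parse(), which reads --problems, --m, --k, and --verbose, while the printed
// help text still shows --groups):
//
//   $ ./examples/13_two_tensor_op_fusion/13_fused_two_gemms_grouped_f16_sm80_rf --problems=10 --verbose=true
//
// When --m/--k are omitted, M and K are drawn as random multiples of the 8-element
// alignment, and each problem's N is pinned to the corresponding ThreadblockShape::kN,
// matching the fused-kernel constraints stated at the top of that file.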
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -152,7 +152,7 @@ bool run_fused_gemm_s8_sm80_rf_res() { using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; - using EpilogueOutputOp0 = + using EpilogueOutputOp0 = cutlass::epilogue::thread::LinearCombinationRelu< ElementOutput, 8 * InstructionShape::kN / 32, @@ -161,7 +161,7 @@ bool run_fused_gemm_s8_sm80_rf_res() { cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling >; - using EpilogueOutputOp1 = + using EpilogueOutputOp1 = cutlass::epilogue::thread::LinearCombinationRelu< ElementOutput, 64 / cutlass::sizeof_bits::value, @@ -194,14 +194,21 @@ bool run_fused_gemm_s8_sm80_rf_res() { SmemAccumulator, 16, 16, - false, cutlass::arch::OpMultiplyAddSaturate >; B2bInterleavedFusedGemmRun fusedGemm; std::cout << "Running Fused back-to-back INT8 NT interleaved GEMMs with RF residency...\n"; - bool passed = fusedGemm.run(gemm_s8_sm80_problem_size_0, gemm_s8_sm80_problem_size_1, alpha0, beta0, alpha1, beta1); + bool passed = fusedGemm.run( + gemm_s8_sm80_problem_size_0, + gemm_s8_sm80_problem_size_1, + alpha0, + beta0, + alpha1, + beta1 + ); + if(passed) std::cout << "Pass\n"; else @@ -210,18 +217,123 @@ bool run_fused_gemm_s8_sm80_rf_res() { return passed; } +bool run_fused_gemm_s8_sm80_rf_res_batch() { + + + cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_0(256, 64, 128); + cutlass::gemm::GemmCoord gemm_s8_sm80_problem_size_1(256, 128, 64); + + using ElementOutput = int8_t; + using ElementAccumulator = int32_t; + using ElementCompute = float; + + ElementCompute alpha0 = ElementCompute(1); + //Fused kernel has built-in bias, setting beta=0 + ElementCompute beta0 = ElementCompute(0); + ElementCompute alpha1 = ElementCompute(1); + ElementCompute beta1 = ElementCompute(1); //beta=1 for bias + + using ThreadblockShape0 = cutlass::gemm::GemmShape<64, 64, 64>; + using WarpShape0 = cutlass::gemm::GemmShape<16, 64, 64>; + using ThreadblockShape1 = cutlass::gemm::GemmShape<64, 128, 64>; + + using WarpShape1 = cutlass::gemm::GemmShape<16, 128, 64>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; + + using EpilogueOutputOp0 = + cutlass::epilogue::thread::LinearCombinationRelu< + ElementOutput, + 8 * InstructionShape::kN / 32, + ElementAccumulator, + ElementCompute, + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling + >; + + using EpilogueOutputOp1 = + cutlass::epilogue::thread::LinearCombinationRelu< + ElementOutput, + 64 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute, + cutlass::epilogue::thread::ScaleType::NoBetaScaling + >; + + const bool SmemAccumulator = false; + + using B2bGemm = cutlass::gemm::device::B2bGemm< + int8_t, + cutlass::layout::ColumnMajorInterleaved<32>, + int8_t, + cutlass::layout::RowMajorInterleaved<32>, + ElementOutput, + cutlass::layout::ColumnMajorInterleaved<32>, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape0, + ThreadblockShape1, + WarpShape0, + WarpShape1, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + 3, + SmemAccumulator, + 16, + 16, + cutlass::arch::OpMultiplyAddSaturate + >; + + B2bInterleavedFusedGemmRun fusedGemm; + + int batch_count = 2; + int64_t batch_stride_A0 = gemm_s8_sm80_problem_size_0.m() * gemm_s8_sm80_problem_size_0.k(); + int64_t batch_stride_B0 = gemm_s8_sm80_problem_size_1.k() * 
gemm_s8_sm80_problem_size_1.n(); + int64_t batch_stride_C0 = gemm_s8_sm80_problem_size_0.m() * gemm_s8_sm80_problem_size_0.n(); + int64_t batch_stride_B1 = gemm_s8_sm80_problem_size_1.k() * gemm_s8_sm80_problem_size_1.n(); + int64_t batch_stride_C1 = gemm_s8_sm80_problem_size_1.n(); + int64_t batch_stride_D1 = gemm_s8_sm80_problem_size_1.m() * gemm_s8_sm80_problem_size_1.n(); + int64_t batch_stride_Bias0 = gemm_s8_sm80_problem_size_0.n(); + int64_t batch_stride_Scale0 = 0; + + std::cout << "Running Fused back-to-back INT8 NT interleaved Batched GEMMs with RF residency...\n"; + bool passed = fusedGemm.run( + gemm_s8_sm80_problem_size_0, + gemm_s8_sm80_problem_size_1, + alpha0, + beta0, + alpha1, + beta1, + cutlass::gemm::GemmUniversalMode::kBatched, + batch_count, + batch_stride_A0, + batch_stride_B0, + batch_stride_C0, + batch_stride_B1, + batch_stride_C1, + batch_stride_D1, + batch_stride_Bias0, + batch_stride_Scale0 + ); + + if(passed) + std::cout << "Pass\n"; + else + std::cout << "Fail\n"; + + return passed; +} int main() { std::vectorfuncs = { &run_nonfused_gemm_s8_sm80, - &run_fused_gemm_s8_sm80_rf_res + &run_fused_gemm_s8_sm80_rf_res, + &run_fused_gemm_s8_sm80_rf_res_batch }; return testRun(80, funcs, "gemm int8 RF residency"); - - } - //////////////////////////////////////////////////////////////////////////////// diff --git a/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_shmem.cu b/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_shmem.cu index 7afe440941..84354221cf 100644 --- a/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_shmem.cu +++ b/examples/13_two_tensor_op_fusion/fused_two_gemms_s8_sm80_shmem.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -151,7 +151,7 @@ bool run_fused_gemm_s8_sm80_shmem() { using WarpShape1 = cutlass::gemm::GemmShape<64, 64, 64>; using InstructionShape = cutlass::gemm::GemmShape<16, 8, 32>; - using EpilogueOutputOp0 = + using EpilogueOutputOp0 = cutlass::epilogue::thread::LinearCombinationRelu< ElementOutput, 8 * InstructionShape::kN / 32, @@ -160,7 +160,7 @@ bool run_fused_gemm_s8_sm80_shmem() { cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling >; - using EpilogueOutputOp1 = + using EpilogueOutputOp1 = cutlass::epilogue::thread::LinearCombinationRelu< ElementOutput, 64 / cutlass::sizeof_bits::value, @@ -168,7 +168,7 @@ bool run_fused_gemm_s8_sm80_shmem() { ElementCompute, cutlass::epilogue::thread::ScaleType::NoBetaScaling >; - + const bool SmemAccumulator = true; using B2bGemm = cutlass::gemm::device::B2bGemm< @@ -193,7 +193,6 @@ bool run_fused_gemm_s8_sm80_shmem() { SmemAccumulator, 16, 16, - false, cutlass::arch::OpMultiplyAddSaturate >; diff --git a/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h b/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h index 306e8cf47e..fca87a1d09 100644 --- a/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h +++ b/examples/13_two_tensor_op_fusion/kernel/b2b_gemm.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -40,19 +40,66 @@ #include "cutlass/matrix_coord.h" #include "cutlass/semaphore.h" +#include "kernel/b2b_gemm_grouped_problem_visitor.h" +#include "threadblock/grouped_threadblock_swizzle.h" + ///////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { namespace gemm { namespace kernel { +namespace detail { + +/// Utility struct for returning the type of the problem visitor used by the swizzling function, +/// if it is a grouped swizzling function, or a default visitor. This is used only for defining +/// the parameters of the problem visitor used in GroupedParams. +template < + typename B2bMma_, + typename ThreadblockSwizzle_, + typename Enable = void +> +struct ProblemVisitorOrDefault; + +/// Return a generic problem visitor for GEMM problems +template < + typename B2bMma_, + typename ThreadblockSwizzle_ +> +struct ProblemVisitorOrDefault::value + >::type> { + using value = B2bGemmGroupedProblemVisitor::value>; +}; + +/// Return the problem visitor specified by the swizzling function +template < + typename B2bMma_, + typename ThreadblockSwizzle_ +> +struct ProblemVisitorOrDefault::value + >::type> { + using value = typename ThreadblockSwizzle_::ProblemVisitor; +}; + +} // namespace detail + ///////////////////////////////////////////////////////////////////////////////////////////////// template < - typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate + typename B2bMma_, ///! Threadblock-scoped matrix multiply-accumulate typename Epilogue_, ///! Epilogue - typename ThreadblockSwizzle_, ///! Threadblock swizzling function - bool SplitKSerial ///! If true, code supporting split-K via serial reduction is enabled. + typename ThreadblockSwizzle_ ///! Threadblock swizzling function > struct B2bGemm { @@ -61,50 +108,225 @@ struct B2bGemm { using OutputOp0 = typename B2bMma::OutputOp; using OutputOp1 = typename Epilogue::OutputOp; using ThreadblockSwizzle = ThreadblockSwizzle_; - static bool const kSplitKSerial = SplitKSerial; + + using ElementA0 = typename B2bMma::IteratorA0::Element; + using LayoutA0 = typename B2bMma::IteratorA0::Layout; + using ElementB0 = typename B2bMma::IteratorB0::Element; + using LayoutB0 = typename B2bMma::IteratorB0::Layout; + using ElementB1 = typename B2bMma::IteratorB1::Element; + using LayoutB1 = typename B2bMma::IteratorB1::Layout; + using ElementC = typename Epilogue::OutputTileIterator::Element; + using LayoutC = typename Epilogue::OutputTileIterator::Layout; + + using ScaleBiasData = typename B2bMma::IteratorAccumulatorScaleBias::Element; + + /// Data types needed for higher-level containers. In some cases, a single type must be exposed + /// despite the B2b GEMM using two GEMMs under the hood. 
In such cases, we select the values from + /// the second GEMM (other than for ElementA/ElementB) + using ElementA = typename B2bMma::IteratorA0::Element; + using LayoutA = typename B2bMma::IteratorA0::Layout; + using ElementB = typename B2bMma::IteratorB0::Element; + using LayoutB = typename B2bMma::IteratorB0::Layout; + + static ComplexTransform const kTransformA = B2bMma::kTransformA; + static ComplexTransform const kTransformB = B2bMma::kTransformB; + using Operator = typename B2bMma::Operator0; + + using OperatorClass = typename Operator::OperatorClass; + using ThreadblockShape = typename B2bMma::Shape0; + using WarpShape = typename Operator::Shape; + using InstructionShape = typename Operator::InstructionShape; + using ArchTag = typename B2bMma::ArchTag; + + static int const kStages = B2bMma::kStages; + static int const kAlignmentA = B2bMma::IteratorA::AccessType::kElements; + static int const kAlignmentB = B2bMma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + using Mma = B2bMma; + using EpilogueOutputOp = OutputOp1; /// Warp count (concept: GemmShape) using WarpCount0 = typename B2bMma::WarpCount0; static int const kThreadCount = 32 * WarpCount0::kCount; + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm; + GemmCoord problem_size_0{0,0,0}; + GemmCoord problem_size_1{0,0,0}; + typename B2bMma::IteratorA0::TensorRef ref_A0{}; + typename B2bMma::IteratorB0::TensorRef ref_B0{}; + typename Epilogue::OutputTileIterator::TensorRef ref_C0{}; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0{}; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0{}; + typename B2bMma::IteratorB1::TensorRef ref_B1{}; + typename Epilogue::OutputTileIterator::TensorRef ref_C1{}; + typename Epilogue::OutputTileIterator::TensorRef ref_D1{}; + int64_t batch_stride_A0{0}; + int64_t batch_stride_B0{0}; + int64_t batch_stride_B1{0}; + int64_t batch_stride_C1{0}; + int64_t batch_stride_D1{0}; + int64_t batch_stride_Bias0{0}; + int64_t batch_stride_Scale0{0}; + typename OutputOp0::Params epilogue0 {}; + typename OutputOp1::Params epilogue1 {}; + int batch_count{1}; + + // + // Methods + // + + /// Default ctor + Arguments() = default; + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + GemmUniversalMode mode_, + GemmCoord problem_size_0_, + GemmCoord problem_size_1_, + typename B2bMma::IteratorA0::TensorRef ref_A0_, + typename B2bMma::IteratorB0::TensorRef ref_B0_, + typename Epilogue::OutputTileIterator::TensorRef ref_C0_, + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0_, + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0_, + typename B2bMma::IteratorB1::TensorRef ref_B1_, + typename Epilogue::OutputTileIterator::TensorRef ref_C1_, + typename Epilogue::OutputTileIterator::TensorRef ref_D1_, + int64_t batch_stride_A0_, + int64_t batch_stride_B0_, + int64_t batch_stride_B1_, + int64_t batch_stride_C1_, + int64_t batch_stride_D1_, + int64_t batch_stride_Bias0_, + int64_t batch_stride_Scale0_, + typename OutputOp0::Params epilogue0_ = typename OutputOp0::Params(), + typename OutputOp1::Params epilogue1_ = typename OutputOp1::Params(), + int batch_count_ = 1 + ): + mode(mode_), + problem_size_0(problem_size_0_), + problem_size_1(problem_size_1_), + ref_A0(ref_A0_), + ref_B0(ref_B0_), + ref_C0(ref_C0_), + ref_Scale0(ref_Scale0_), + ref_Bias0(ref_Bias0_), + 
ref_B1(ref_B1_), + ref_C1(ref_C1_), + ref_D1(ref_D1_), + batch_stride_A0(batch_stride_A0_), + batch_stride_B0(batch_stride_B0_), + batch_stride_B1(batch_stride_B1_), + batch_stride_C1(batch_stride_C1_), + batch_stride_D1(batch_stride_D1_), + batch_stride_Bias0(batch_stride_Bias0_), + batch_stride_Scale0(batch_stride_Scale0_), + epilogue0(epilogue0_), + epilogue1(epilogue1_), + batch_count(batch_count_) { + } + }; + + // Arguments structure for grouped B2B problems + struct GroupedArguments { + GemmCoord* problem_size_0; + GemmCoord* problem_size_1; + typename B2bMma::IteratorA0::TensorRef* ref_A0; + typename B2bMma::IteratorB0::TensorRef* ref_B0; + typename Epilogue::OutputTileIterator::TensorRef* ref_C0; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0; + typename B2bMma::IteratorB1::TensorRef* ref_B1; + typename Epilogue::OutputTileIterator::TensorRef* ref_C1; + typename Epilogue::OutputTileIterator::TensorRef* ref_D1; + + // Epilogue params remain constant across all problmes in the group. Thus, + // the parameter here is not a pointer. + typename OutputOp0::Params epilogue0; + typename OutputOp1::Params epilogue1; + + int problem_count; + int threadblock_count; + GemmCoord* host_problem_sizes; + + CUTLASS_HOST_DEVICE + GroupedArguments( + int problem_count, + GemmCoord* problem_size_0_, + GemmCoord* problem_size_1_, + typename B2bMma::IteratorA0::TensorRef* ref_A0_, + typename B2bMma::IteratorB0::TensorRef* ref_B0_, + typename Epilogue::OutputTileIterator::TensorRef* ref_C0_, + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0_, + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0_, + typename B2bMma::IteratorB1::TensorRef* ref_B1_, + typename Epilogue::OutputTileIterator::TensorRef* ref_C1_, + typename Epilogue::OutputTileIterator::TensorRef* ref_D1_, + typename OutputOp0::Params epilogue0_ = typename OutputOp0::Params(), + typename OutputOp1::Params epilogue1_ = typename OutputOp1::Params(), + int threadblock_count = 0 + ) : problem_size_0(problem_size_0_), problem_size_1(problem_size_1_), + ref_A0(ref_A0_), ref_B0(ref_B0_), ref_C0(ref_C0_), + ref_Scale0(ref_Scale0_), ref_Bias0(ref_Bias0_), ref_B1(ref_B1_), + ref_C1(ref_C1_), ref_D1(ref_D1_), epilogue0(epilogue0_), epilogue1(epilogue1_), + problem_count(problem_count), + threadblock_count(threadblock_count) + {} + }; + /// Parameters structure struct Params { - cutlass::gemm::GemmCoord problem_size_0; - cutlass::gemm::GemmCoord problem_size_1; - cutlass::gemm::GemmCoord grid_tiled_shape; - int swizzle_log_tile; - typename B2bMma::IteratorA0::Params params_A0; - typename B2bMma::IteratorA0::TensorRef ref_A0; - typename B2bMma::IteratorB0::Params params_B0; - typename B2bMma::IteratorB0::TensorRef ref_B0; - typename Epilogue::OutputTileIterator::Params params_C0; - typename Epilogue::OutputTileIterator::TensorRef ref_C0; - typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0; - typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0; - typename B2bMma::IteratorB1::Params params_B1; - typename B2bMma::IteratorB1::TensorRef ref_B1; - typename Epilogue::OutputTileIterator::Params params_C1; - typename Epilogue::OutputTileIterator::TensorRef ref_C1; - typename Epilogue::OutputTileIterator::Params params_D1; - typename Epilogue::OutputTileIterator::TensorRef ref_D1; - typename OutputOp0::Params output_op_0; - typename OutputOp1::Params output_op_1; - int *semaphore; - int 
gemm_k_iterations_0; - int gemm_k_size_0; - int gemm_k_iterations_1; - int gemm_k_size_1; + cutlass::gemm::GemmUniversalMode mode = cutlass::gemm::GemmUniversalMode::kGemm; + cutlass::gemm::GemmCoord problem_size_0{}; + cutlass::gemm::GemmCoord problem_size_1{}; + cutlass::gemm::GemmCoord grid_tiled_shape{}; + int swizzle_log_tile{0}; + typename B2bMma::IteratorA0::Params params_A0{}; + typename B2bMma::IteratorA0::TensorRef ref_A0{}; + typename B2bMma::IteratorB0::Params params_B0{}; + typename B2bMma::IteratorB0::TensorRef ref_B0{}; + typename Epilogue::OutputTileIterator::Params params_C0{}; + typename Epilogue::OutputTileIterator::TensorRef ref_C0{}; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Scale0{}; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef ref_Bias0{}; + typename B2bMma::IteratorB1::Params params_B1{}; + typename B2bMma::IteratorB1::TensorRef ref_B1{}; + typename Epilogue::OutputTileIterator::Params params_C1{}; + typename Epilogue::OutputTileIterator::TensorRef ref_C1{}; + typename Epilogue::OutputTileIterator::Params params_D1{}; + typename Epilogue::OutputTileIterator::TensorRef ref_D1{}; + typename OutputOp0::Params output_op_0{}; + typename OutputOp1::Params output_op_1{}; + int64_t batch_stride_A0{0}; + int64_t batch_stride_B0{0}; + int64_t batch_stride_B1{0}; + int64_t batch_stride_C1{0}; + int64_t batch_stride_D1{0}; + int64_t batch_stride_Bias0{0}; + int64_t batch_stride_Scale0{0}; + int *semaphore = nullptr; + int gemm_k_iterations_0{0}; + int gemm_k_size_0{0}; + int gemm_k_iterations_1{0}; + int gemm_k_size_1{0}; // // Methods // - CUTLASS_HOST_DEVICE - Params(): swizzle_log_tile(0), semaphore(0), gemm_k_iterations_0(0), gemm_k_size_0(0), - gemm_k_iterations_1(0), gemm_k_size_1(0) { } + Params() = default; CUTLASS_HOST_DEVICE Params( + cutlass::gemm::GemmUniversalMode mode, cutlass::gemm::GemmCoord const & problem_size_0, cutlass::gemm::GemmCoord const & problem_size_1, cutlass::gemm::GemmCoord const & grid_tiled_shape, @@ -116,14 +338,22 @@ struct B2bGemm { typename B2bMma::IteratorB1::TensorRef ref_B1, typename Epilogue::OutputTileIterator::TensorRef ref_C1, typename Epilogue::OutputTileIterator::TensorRef ref_D1, + int64_t batch_stride_A0, + int64_t batch_stride_B0, + int64_t batch_stride_B1, + int64_t batch_stride_C1, + int64_t batch_stride_D1, + int64_t batch_stride_Bias0, + int64_t batch_stride_Scale0, typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(), typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(), int *workspace = nullptr ): + mode(mode), problem_size_0(problem_size_0), problem_size_1(problem_size_1), grid_tiled_shape(grid_tiled_shape), - swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + swizzle_log_tile(ThreadblockSwizzle::get_log_tile(grid_tiled_shape)), params_A0(ref_A0.layout()), ref_A0(ref_A0), params_B0(ref_B0.layout()), @@ -138,6 +368,13 @@ struct B2bGemm { ref_C1(ref_C1), params_D1(ref_D1.layout()), ref_D1(ref_D1), + batch_stride_A0(batch_stride_A0), + batch_stride_B0(batch_stride_B0), + batch_stride_B1(batch_stride_B1), + batch_stride_C1(batch_stride_C1), + batch_stride_D1(batch_stride_D1), + batch_stride_Bias0(batch_stride_Bias0), + batch_stride_Scale0(batch_stride_Scale0), output_op_0(output_op_0), output_op_1(output_op_1) { @@ -152,6 +389,81 @@ struct B2bGemm { } }; + struct GroupedParams { + cutlass::gemm::GemmCoord* problem_size_0; + cutlass::gemm::GemmCoord* problem_size_1; + cutlass::gemm::GemmCoord* grid_tiled_shape; + typename 
B2bMma::IteratorA0::TensorRef* ref_A0; + typename B2bMma::IteratorB0::TensorRef* ref_B0; + typename Epilogue::OutputTileIterator::TensorRef* ref_C0; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Scale0; + typename B2bMma::IteratorAccumulatorScaleBias::TensorRef* ref_Bias0; + typename B2bMma::IteratorB1::TensorRef* ref_B1; + typename Epilogue::OutputTileIterator::TensorRef* ref_C1; + typename Epilogue::OutputTileIterator::TensorRef* ref_D1; + + // Epilogue params remain constant across all problmes in the group. Thus, + // the parameter here is not a pointer. + typename OutputOp0::Params output_op_0; + typename OutputOp1::Params output_op_1; + + using ProblemVisitor = typename detail::ProblemVisitorOrDefault::value; + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + int* workspace; + + CUTLASS_HOST_DEVICE + GroupedParams() {} + + CUTLASS_HOST_DEVICE + GroupedParams( + GroupedArguments const &args, + void *workspace = nullptr, + int tile_count = 0 + ) : + problem_size_0(args.problem_size_0), problem_size_1(args.problem_size_1), + ref_A0(args.ref_A0), ref_B0(args.ref_B0), ref_C0(args.ref_C0), + ref_Scale0(args.ref_Scale0), ref_Bias0(args.ref_Bias0), ref_B1(args.ref_B1), ref_C1(args.ref_C1), ref_D1(args.ref_D1), + output_op_0(args.epilogue0), output_op_1(args.epilogue1), + problem_visitor(args.problem_size_0, args.problem_size_1, args.problem_count, workspace, tile_count), + threadblock_count(args.threadblock_count), + workspace(reinterpret_cast(workspace)) {} + + CUTLASS_HOST_DEVICE + void transpose() { + // Only row-major outputs are currently supported, so no transpose is performed + } + + /// Returns non-grouped paramaters to be used as input to the kernel-level + /// operator for the problem indicated by problem_visitor. 
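  /// A minimal sketch of how a grouped driver is expected to consume these Params,
  /// assuming the base GroupedProblemVisitor interface (next_tile()/advance()) and an
  /// externally constructed grouped swizzle, shared storage, and kernel instance;
  /// the names here are illustrative:
  ///
  ///   while (problem_visitor.next_tile()) {
  ///     typename B2bGemm::Params p = grouped_params.to_single_params(problem_visitor);
  ///     b2b_gemm_kernel.run_with_swizzle(p, shared_storage, swizzle);
  ///     problem_visitor.advance(gridDim.x);
  ///   }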
+ CUTLASS_HOST_DEVICE + Params to_single_params(const ProblemVisitor& problem_visitor) const { + GemmCoord problem_size0 = problem_visitor.problem_size0(); + GemmCoord problem_size1 = problem_visitor.problem_size1(); + int32_t idx = problem_visitor.problem_index(); + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size1); + + return Params( + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size0, + problem_size1, + grid_shape, + ref_A0[idx], + ref_B0[idx], + ref_C0[idx], + ref_Scale0[idx], + ref_Bias0[idx], + ref_B1[idx], + ref_C1[idx], + ref_D1[idx], + 0, 0, 0, 0, 0, 0, 0, // Batched B2B GEMMs within the grouped kernel are currently unsupported + output_op_0, + output_op_1, + workspace + ); + } + }; + /// Shared memory storage structure union SharedStorage { typename B2bMma::B2bMmaSharedStorage main_loop; @@ -163,7 +475,7 @@ struct B2bGemm { // CUTLASS_HOST_DEVICE - B2bGemm() { } + B2bGemm() { } /// Determines whether kernel satisfies alignment static Status can_implement( @@ -223,7 +535,7 @@ struct B2bGemm { if(problem_size_0.n() > B2bMma::Shape0::kN) return Status::kErrorInvalidProblem; - + if(problem_size_1.n() > B2bMma::Shape1::kN) return Status::kErrorInvalidProblem; @@ -233,9 +545,13 @@ struct B2bGemm { /// Executes one GEMM CUTLASS_DEVICE void operator()(Params const ¶ms, SharedStorage &shared_storage) { - - // Compute threadblock location ThreadblockSwizzle threadblock_swizzle; + run_with_swizzle(params, shared_storage, threadblock_swizzle); + } + + /// Executes one GEMM with an externally-provided swizzling function + CUTLASS_DEVICE + void run_with_swizzle(Params const ¶ms, SharedStorage &shared_storage, ThreadblockSwizzle& threadblock_swizzle) { cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); @@ -247,37 +563,64 @@ struct B2bGemm { return; } + ElementA0 *ptr_A0 = static_cast(params.ref_A0.data()); + ElementB0 *ptr_B0 = static_cast(params.ref_B0.data()); + ElementB1 *ptr_B1 = static_cast(params.ref_B1.data()); + + ScaleBiasData *ptr_Bias0 = static_cast(params.ref_Bias0.data()); + ScaleBiasData *ptr_Scale0 = static_cast(params.ref_Scale0.data()); + + int offset_k_0 = 0; + int offset_k_1 = 0; + + int problem_size_k_0 = params.problem_size_0.k(); + int problem_size_k_1 = params.problem_size_1.k(); + + if (params.mode == GemmUniversalMode::kGemm) { + + // Problem size is a function of threadblock index in the K dimension + problem_size_k_0 = min( + problem_size_k_0, + (threadblock_tile_offset.k() + 1) * params.gemm_k_size_0); + + // Problem size is a function of threadblock index in the K dimension + problem_size_k_1 = min( + problem_size_k_1, + (threadblock_tile_offset.k() + 1) * params.gemm_k_size_1); + + offset_k_0 = threadblock_tile_offset.k() * params.gemm_k_size_0; + offset_k_1 = threadblock_tile_offset.k() * params.gemm_k_size_1; + } + + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_A0 += threadblock_tile_offset.k() * params.batch_stride_A0; + ptr_B0 += threadblock_tile_offset.k() * params.batch_stride_B0; + ptr_B1 += threadblock_tile_offset.k() * params.batch_stride_B1; + ptr_Bias0 += threadblock_tile_offset.k() * params.batch_stride_Bias0; + ptr_Scale0 += threadblock_tile_offset.k() * params.batch_stride_Scale0; + } + // Compute initial location in logical coordinates cutlass::MatrixCoord tb_offset_A0{ threadblock_tile_offset.m() * B2bMma::Shape0::kM, - threadblock_tile_offset.k() * params.gemm_k_size_0, + offset_k_0, }; cutlass::MatrixCoord tb_offset_B0{ - threadblock_tile_offset.k() 
* params.gemm_k_size_0, + offset_k_0, threadblock_tile_offset.n() * B2bMma::Shape0::kN }; cutlass::MatrixCoord tb_offset_B1{ - threadblock_tile_offset.k() * params.gemm_k_size_1, + offset_k_1, threadblock_tile_offset.n() * B2bMma::Shape1::kN }; - // Problem size is a function of threadblock index in the K dimension - int problem_size_k_0 = min( - params.problem_size_0.k(), - (threadblock_tile_offset.k() + 1) * params.gemm_k_size_0); - // Compute threadblock-scoped matrix multiply-add int gemm_k_iterations_0 = (problem_size_k_0 - tb_offset_A0.column() + B2bMma::Shape0::kK - 1) / B2bMma::Shape0::kK; - // Problem size is a function of threadblock index in the K dimension - int problem_size_k_1 = min( - params.problem_size_1.k(), - (threadblock_tile_offset.k() + 1) * params.gemm_k_size_1); - // Compute threadblock-scoped matrix multiply-add -// int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK; + // int gemm_k_iterations_1 = (problem_size_k_1 - tb_offset_B1.row() + B2bMma::Shape1::kK - 1) / B2bMma::Shape1::kK; // Compute position within threadblock @@ -286,34 +629,33 @@ struct B2bGemm { // Construct iterators to A and B operands typename B2bMma::IteratorA0 iterator_A0( params.params_A0, - params.ref_A0.data(), + ptr_A0, {params.problem_size_0.m(), problem_size_k_0}, thread_idx, tb_offset_A0); typename B2bMma::IteratorB0 iterator_B0( params.params_B0, - params.ref_B0.data(), + ptr_B0, {problem_size_k_0, params.problem_size_0.n()}, thread_idx, tb_offset_B0); typename B2bMma::IteratorB1 iterator_B1( params.params_B1, - params.ref_B1.data(), + ptr_B1, {problem_size_k_1, params.problem_size_1.n()}, thread_idx, tb_offset_B1); - // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0x1f, threadIdx.x / 32, 0); + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); int lane_idx = threadIdx.x % 32; // Construct iterators to accumulator scale/bias vector typename B2bMma::IteratorAccumulatorScaleBias iterator_Scale0( - params.ref_Scale0.data(), + ptr_Scale0, {1, params.problem_size_0.n()}, thread_idx, warp_idx, @@ -323,7 +665,7 @@ struct B2bGemm { ); typename B2bMma::IteratorAccumulatorScaleBias iterator_Bias0( - params.ref_Bias0.data(), + ptr_Bias0, {1, params.problem_size_0.n()}, thread_idx, warp_idx, @@ -332,16 +674,19 @@ struct B2bGemm { ) ); - - // // Main loop // OutputOp0 output_op_0(params.output_op_0); + if (cutlass::gemm::threadblock::detail::IsGroupedSwizzle::value) { + // Wait for all threads to finish their epilogue phases from the previous tile. 
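      // Shared memory is a union of the mainloop and epilogue storage, and with a grouped
      // swizzle the same threadblock runs several problems back to back; without this
      // barrier, warps starting the next problem's mainloop could overwrite storage that
      // other warps are still reading in the previous problem's epilogue.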
+ __syncthreads(); + } + // Construct thread-scoped matrix multiply - B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx, params.problem_size_0.n()); typename B2bMma::FragmentC0 src_accum; typename B2bMma::FragmentC1 accumulators; @@ -349,11 +694,9 @@ struct B2bGemm { src_accum.clear(); accumulators.clear(); - if (!kSplitKSerial || gemm_k_iterations_0 > 0) { - // Compute threadblock-scoped matrix multiply-add - b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0, - iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0); - } + // Compute threadblock-scoped matrix multiply-add + b2bMma(gemm_k_iterations_0, accumulators, iterator_A0, iterator_B0, + iterator_Scale0, iterator_Bias0, iterator_B1, src_accum, output_op_0); // // Epilogue @@ -376,23 +719,32 @@ struct B2bGemm { int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + ElementC *ptr_C1 = static_cast(params.ref_C1.data()); + ElementC *ptr_D1 = static_cast(params.ref_D1.data()); + // Construct the semaphore. Semaphore semaphore(params.semaphore + block_idx, thread_idx); - // If performing a reduction via split-K, fetch the initial synchronization - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { - - // Fetch the synchronization lock initially but do not block. - semaphore.fetch(); + if (params.mode == GemmUniversalMode::kGemm) { + // If performing a reduction via split-K, fetch the initial synchronization + + if (params.grid_tiled_shape.k() > 1) { + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); - // Indicate which position in a serial reduction the output operator is currently updating - output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + // Indicate which position in a serial reduction the output operator is currently updating + output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + } + else if (params.mode == GemmUniversalMode::kBatched) { + ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C1; + ptr_D1 += threadblock_tile_offset.k() * params.batch_stride_D1; } // Tile iterator loading from source tensor. typename Epilogue::OutputTileIterator iterator_C1( params.params_C1, - params.ref_C1.data(), + ptr_C1, params.problem_size_1.mn(), thread_idx, threadblock_offset @@ -401,21 +753,21 @@ struct B2bGemm { // Tile iterator writing to destination tensor. typename Epilogue::OutputTileIterator iterator_D1( params.params_D1, - params.ref_D1.data(), + ptr_D1, params.problem_size_1.mn(), thread_idx, threadblock_offset ); Epilogue epilogue( - shared_storage.epilogue, - thread_idx, - warp_idx, + shared_storage.epilogue, + thread_idx, + warp_idx, lane_idx); // Wait on the semaphore - this latency may have been covered by iterator construction - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { - + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. if (threadblock_tile_offset.k()) { iterator_C1 = iterator_D1; @@ -427,14 +779,14 @@ struct B2bGemm { } // Execute the epilogue operator to update the destination tensor. 
- epilogue(output_op_1, iterator_D1, accumulators, iterator_C1); - + epilogue(output_op_1, iterator_D1, accumulators, iterator_C1); + // // Release the semaphore // - if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { - + if (params.mode == GemmUniversalMode::kGemm && params.grid_tiled_shape.k() > 1) { + int lock = 0; if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { @@ -457,4 +809,3 @@ struct B2bGemm { } // namespace kernel } // namespace gemm } // namespace cutlass - diff --git a/examples/13_two_tensor_op_fusion/kernel/b2b_gemm_grouped_problem_visitor.h b/examples/13_two_tensor_op_fusion/kernel/b2b_gemm_grouped_problem_visitor.h new file mode 100644 index 0000000000..13faadf039 --- /dev/null +++ b/examples/13_two_tensor_op_fusion/kernel/b2b_gemm_grouped_problem_visitor.h @@ -0,0 +1,157 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Scheduler for grouped B2b GEMMs +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/gemm/kernel/grouped_problem_visitor.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Visitor class to abstract away the algorithm for iterating over tiles +template +struct B2bGemmGroupedProblemVisitor : public GroupedProblemVisitor< + detail::GemmGroupedProblemSizeHelper, + ThreadblockShape, + GroupScheduleMode_, + PrefetchTileCount, + ThreadCount> { + + using ProblemSizeHelper = detail::GemmGroupedProblemSizeHelper; + using Base = GroupedProblemVisitor; + using BaseParams = typename Base::Params; + using SharedStorage = typename Base::SharedStorage; + static bool const kTransposed = Transposed; + + cutlass::gemm::GemmCoord const *problem_sizes0; + cutlass::gemm::GemmCoord const *problem_sizes1; + + struct Params { + cutlass::gemm::GemmCoord const *problem_sizes0; + cutlass::gemm::GemmCoord const *problem_sizes1; + int32_t problem_count; + void const *workspace; + int32_t tile_count; + + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + Params(): problem_sizes0(nullptr), problem_sizes1(nullptr), + problem_count(0), workspace(nullptr), tile_count(0) { } + + /// Ctor + CUTLASS_HOST_DEVICE + Params( + cutlass::gemm::GemmCoord const *problem_sizes0, + cutlass::gemm::GemmCoord const *problem_sizes1, + int32_t problem_count, + void const *workspace = nullptr, + int32_t tile_count = 0 + ): + problem_sizes0(problem_sizes0), + problem_sizes1(problem_sizes1), + problem_count(problem_count), + workspace(workspace), + tile_count(tile_count) + {} + + /// Convert the B2b-GEMM-specific parameters to those used by the base class + CUTLASS_HOST_DEVICE + BaseParams to_base() const { + return BaseParams(// Set problem_sizes as problem_sizes0 because these determine + // shape of the grid used in the non-grouped B2b GEMM + problem_sizes0, + problem_count, + workspace, + tile_count); + } + + }; + + // + // Methods + // + CUTLASS_DEVICE + B2bGemmGroupedProblemVisitor( + Params const ¶ms_, + SharedStorage &shared_storage_, + int32_t block_idx + ): Base ( + params_.to_base(), + shared_storage_, block_idx), + problem_sizes0(params_.problem_sizes0), + problem_sizes1(params_.problem_sizes1) + {} + + /// Returns the problem size 0 for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size0() const { + GemmCoord problem = problem_sizes0[this->problem_idx]; + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } + + /// Returns the problem size 1 for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size1() const { + GemmCoord problem = problem_sizes1[this->problem_idx]; + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/13_two_tensor_op_fusion/kernel/b2b_implicit_gemm_convolution.h 
b/examples/13_two_tensor_op_fusion/kernel/b2b_implicit_gemm_convolution.h index 8ebc10c0a2..d249a2c237 100644 --- a/examples/13_two_tensor_op_fusion/kernel/b2b_implicit_gemm_convolution.h +++ b/examples/13_two_tensor_op_fusion/kernel/b2b_implicit_gemm_convolution.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop.h b/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop.h index da7c9aef68..1b604c040b 100644 --- a/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop.h +++ b/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm75.h b/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm75.h index bd07b8b954..0168637b6d 100644 --- a/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm75.h +++ b/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm75.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm80.h b/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm80.h index ab45a7a411..d76fe8125d 100644 --- a/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm80.h +++ b/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_sm80.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm75.h b/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm75.h index 454f13ffd4..462ad1efd7 100644 --- a/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm75.h +++ b/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm75.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h b/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h index 55619134d9..e953567610 100644 --- a/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h +++ b/examples/13_two_tensor_op_fusion/kernel/default_b2b_conv2d_fprop_smem_accumulator_sm80.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h b/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h index 4d807b9dc4..2ad3d7f386 100644 --- a/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h +++ b/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -30,10 +30,10 @@ **************************************************************************************************/ /*! \file - \brief + \brief Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with the appropriate threadblock-scoped epilogue. - + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are accommodated by exchanging A and B operands and assuming transposed layouts. Partial specializations here choose 'device::GemmTransposed' to implement this functionality. 
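The hunks that follow drop the `SplitKSerial` flag from `DefaultB2bGemm` and add a trailing `typename Enable = void` parameter; each partial specialization then constrains itself with `platform::enable_if` (CUTLASS's counterpart to `std::enable_if`) on the `IsGroupedSwizzle` trait, so the grouped kernel is selected purely from the threadblock swizzle type. A minimal standalone sketch of that selection mechanism, using assumed stand-in names (`IsGroupedSwizzleLike`, `DefaultKernelSelector`, `GroupedSwizzle`) rather than the real CUTLASS types, and not part of the patch itself:

#include <type_traits>
#include <iostream>

// Stand-in trait playing the role of detail::IsGroupedSwizzle<Swizzle>.
template <typename Swizzle>
struct IsGroupedSwizzleLike : std::false_type {};

struct LinearSwizzle {};
struct GroupedSwizzle {};

template <>
struct IsGroupedSwizzleLike<GroupedSwizzle> : std::true_type {};

// Primary template with a trailing SFINAE hook, mirroring the extra
// `typename Enable = void` parameter added to DefaultB2bGemm.
template <typename Swizzle, typename Enable = void>
struct DefaultKernelSelector {
  static const char *name() { return "non-grouped B2b kernel"; }
};

// Partial specialization chosen only when the swizzle reports itself as grouped.
template <typename Swizzle>
struct DefaultKernelSelector<
    Swizzle, typename std::enable_if<IsGroupedSwizzleLike<Swizzle>::value>::type> {
  static const char *name() { return "grouped B2b kernel"; }
};

int main() {
  std::cout << DefaultKernelSelector<LinearSwizzle>::name() << "\n";   // non-grouped B2b kernel
  std::cout << DefaultKernelSelector<GroupedSwizzle>::name() << "\n";  // grouped B2b kernel
  return 0;
}

In the patch the non-grouped specializations appear to carry the complementary `enable_if` condition (the negated trait), which is what lets both families coexist without ambiguity.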
@@ -63,7 +63,9 @@ #include "cutlass/transform/threadblock/predicated_tile_iterator.h" #include "kernel/b2b_gemm.h" +#include "kernel/grouped.h" #include "threadblock/default_b2b_mma.h" +#include "threadblock/grouped_threadblock_swizzle.h" //////////////////////////////////////////////////////////////////////////////// @@ -73,6 +75,9 @@ namespace kernel { //////////////////////////////////////////////////////////////////////////////// +template +using IsGroupedSwizzle = cutlass::gemm::threadblock::detail::IsGroupedSwizzle; + template < /// Element type for A matrix operand typename ElementA_, @@ -114,12 +119,12 @@ template < typename ThreadblockSwizzle, /// Number of stages used in the pipelined mainloop int Stages, - /// If true, kernel is configured to support serial reduction in the epilogue - bool SplitKSerial, /// Operation performed by GEMM typename Operator, /// Stage accumulator in shared memory - bool SmemAccumulator = false + bool SmemAccumulator = false, + /// Whether or not the operation is grouped + typename Enable = void > struct DefaultB2bGemm; @@ -161,17 +166,77 @@ template < typename ThreadblockSwizzle, /// Number of stages used in the pipelined mainloop int Stages, - /// If true, kernel is configured to support serial reduction in the - /// epilogue - bool SplitKSerial, /// Operation performed by GEMM typename Operator> struct DefaultB2bGemm { + EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages, + Operator, false, typename platform::enable_if::value>::type> { + /// Define the threadblock-scoped matrix multiply-accumulate + using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + InstructionShape, Stages, Operator, EpilogueOutputOp0>::ThreadblockB2bMma; + + static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; + + /// Define the epilogue + using Epilogue = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape1, typename B2bMma::Operator1, kPartitionsK1, EpilogueOutputOp1, + EpilogueOutputOp1::kCount>::Epilogue; + + /// Define the kernel-level GEMM operator. 
+ using B2bGemmKernel = kernel::B2bGemm; +}; + +/// Partial specialization for Ampere Architecture with grouped operation +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of A matrix in units of elements + int kAlignmentB, + /// Element type for C and D matrix operands + typename ElementC, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape0, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape1, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape0, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape1, + /// Warp-level tile size (concept: GemmShape) + typename InstructionShape, + /// Epilogue output operator + typename EpilogueOutputOp0, + /// Epilogue output operator + typename EpilogueOutputOp1, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation performed by GEMM + typename Operator> +struct DefaultB2bGemm::value>::type> { /// Define the threadblock-scoped matrix multiply-accumulate using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, @@ -188,7 +253,9 @@ struct DefaultB2bGemm::Epilogue; /// Define the kernel-level GEMM operator. - using B2bGemmKernel = kernel::B2bGemm; + using UnderlyingB2bGemmKernel = kernel::B2bGemm; + + using B2bGemmKernel = kernel::GroupedKernel; }; @@ -228,8 +295,6 @@ template < typename EpilogueOutputOp1, /// Threadblock-level swizzling operator typename ThreadblockSwizzle, - /// If true, kernel is configured to support serial reduction in the epilogue - bool SplitKSerial, /// Operation performed by GEMM typename Operator > @@ -249,8 +314,9 @@ struct DefaultB2bGemm< EpilogueOutputOp1, ThreadblockSwizzle, 2, - SplitKSerial, - Operator + Operator, + false, + typename platform::enable_if::value>::type > { /// Define the threadblock-scoped matrix multiply-accumulate @@ -274,7 +340,7 @@ struct DefaultB2bGemm< Operator, EpilogueOutputOp0 >::ThreadblockB2bMma; - + static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; /// Define the epilogue @@ -287,7 +353,7 @@ struct DefaultB2bGemm< >::Epilogue; /// Define the kernel-level GEMM operator. 
- using B2bGemmKernel = kernel::B2bGemm; + using B2bGemmKernel = kernel::B2bGemm; }; @@ -323,20 +389,17 @@ template < int Stages, /// Number of Interleaved k int InterleavedK, - /// If true, kernel is configured to support serial reduction in the - /// epilogue - bool SplitKSerial, /// Operation performed by GEMM typename Operator> struct DefaultB2bGemm< ElementA, layout::ColumnMajorInterleaved, kAlignmentA, - ElementB, layout::RowMajorInterleaved, kAlignmentB, + ElementB, layout::RowMajorInterleaved, kAlignmentB, ElementC, layout::ColumnMajorInterleaved, int32_t, arch::OpClassTensorOp, arch::Sm80, ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages, - SplitKSerial, Operator> { + Operator, false, typename platform::enable_if::value>::type> { using LayoutA = layout::ColumnMajorInterleaved; using LayoutB = layout::RowMajorInterleaved; using LayoutC = layout::ColumnMajorInterleaved; @@ -360,7 +423,7 @@ struct DefaultB2bGemm< 64 / sizeof_bits::value, InterleavedK>::Epilogue; /// Define the kernel-level GEMM operator. - using B2bGemmKernel = kernel::B2bGemm; + using B2bGemmKernel = kernel::B2bGemm; }; //////////////////////////////////////////////////////////////////////////////// @@ -396,19 +459,17 @@ template < typename ThreadblockSwizzle, /// Number of Interleaved k int InterleavedK, - /// If true, kernel is configured to support serial reduction in the - /// epilogue - bool SplitKSerial, /// Operation performed by GEMM typename Operator> struct DefaultB2bGemm, kAlignmentA, ElementB, layout::RowMajorInterleaved, kAlignmentB, ElementC, layout::ColumnMajorInterleaved, - int32_t, arch::OpClassTensorOp, arch::Sm75, + int32_t, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, - ThreadblockSwizzle, 2, SplitKSerial, Operator> { + ThreadblockSwizzle, 2, Operator, false, + typename platform::enable_if::value>::type> { using LayoutA = layout::ColumnMajorInterleaved; using LayoutB = layout::RowMajorInterleaved; using LayoutC = layout::ColumnMajorInterleaved; @@ -418,7 +479,7 @@ struct DefaultB2bGemm, /// Define the threadblock-scoped matrix multiply-accumulate using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, LayoutC, - arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1, + arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, InstructionShape, 2, Operator, EpilogueOutputOp0, true>::ThreadblockB2bMma; static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; @@ -430,7 +491,7 @@ struct DefaultB2bGemm, 64 / sizeof_bits::value, InterleavedK>::Epilogue; /// Define the kernel-level GEMM operator. 
- using B2bGemmKernel = kernel::B2bGemm; + using B2bGemmKernel = kernel::B2bGemm; }; //////////////////////////////////////////////////////////////////////////////// diff --git a/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm_smem_accumulator.h b/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm_smem_accumulator.h index fcff8672bf..ad548bc98b 100644 --- a/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm_smem_accumulator.h +++ b/examples/13_two_tensor_op_fusion/kernel/default_b2b_gemm_smem_accumulator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -30,10 +30,10 @@ **************************************************************************************************/ /*! \file - \brief + \brief Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with the appropriate threadblock-scoped epilogue. - + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are accommodated by exchanging A and B operands and assuming transposed layouts. Partial specializations here choose 'device::GemmTransposed' to implement this functionality. @@ -112,22 +112,19 @@ template < typename ThreadblockSwizzle, /// Number of stages used in the pipelined mainloop int Stages, - /// If true, kernel is configured to support serial reduction in the - /// epilogue - bool SplitKSerial, /// Operation performed by GEMM typename Operator> struct DefaultB2bGemm { /// Define the threadblock-scoped matrix multiply-accumulate using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, arch::Sm80, - ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, InstructionShape, Stages, Operator, EpilogueOutputOp0, false, true>::ThreadblockB2bMma; static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; @@ -139,10 +136,9 @@ struct DefaultB2bGemm::Epilogue; /// Define the kernel-level GEMM operator. - using B2bGemmKernel = kernel::B2bGemm; + using B2bGemmKernel = kernel::B2bGemm; }; - //////////////////////////////////////////////////////////////////////////////// /// Partial specialization for Turing Architecture @@ -179,8 +175,6 @@ template < typename EpilogueOutputOp1, /// Threadblock-level swizzling operator typename ThreadblockSwizzle, - /// If true, kernel is configured to support serial reduction in the epilogue - bool SplitKSerial, /// Operation performed by GEMM typename Operator > @@ -200,7 +194,6 @@ struct DefaultB2bGemm< EpilogueOutputOp1, ThreadblockSwizzle, 2, - SplitKSerial, Operator, true > { @@ -228,7 +221,7 @@ struct DefaultB2bGemm< false, true >::ThreadblockB2bMma; - + static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; /// Define the epilogue @@ -241,7 +234,7 @@ struct DefaultB2bGemm< >::Epilogue; /// Define the kernel-level GEMM operator. 
- using B2bGemmKernel = kernel::B2bGemm; + using B2bGemmKernel = kernel::B2bGemm; }; @@ -277,20 +270,17 @@ template < int Stages, /// Number of Interleaved k int InterleavedK, - /// If true, kernel is configured to support serial reduction in the - /// epilogue - bool SplitKSerial, /// Operation performed by GEMM typename Operator> struct DefaultB2bGemm< ElementA, layout::ColumnMajorInterleaved, kAlignmentA, - ElementB, layout::RowMajorInterleaved, kAlignmentB, + ElementB, layout::RowMajorInterleaved, kAlignmentB, ElementC, layout::ColumnMajorInterleaved, int32_t, arch::OpClassTensorOp, arch::Sm80, ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, ThreadblockSwizzle, Stages, - SplitKSerial, Operator, true> { + Operator, true> { using LayoutA = layout::ColumnMajorInterleaved; using LayoutB = layout::RowMajorInterleaved; using LayoutC = layout::ColumnMajorInterleaved; @@ -314,7 +304,7 @@ struct DefaultB2bGemm< 64 / sizeof_bits::value, InterleavedK>::Epilogue; /// Define the kernel-level GEMM operator. - using B2bGemmKernel = kernel::B2bGemm; + using B2bGemmKernel = kernel::B2bGemm; }; //////////////////////////////////////////////////////////////////////////////// @@ -350,19 +340,16 @@ template < typename ThreadblockSwizzle, /// Number of Interleaved k int InterleavedK, - /// If true, kernel is configured to support serial reduction in the - /// epilogue - bool SplitKSerial, /// Operation performed by GEMM typename Operator> struct DefaultB2bGemm, kAlignmentA, ElementB, layout::RowMajorInterleaved, kAlignmentB, ElementC, layout::ColumnMajorInterleaved, - int32_t, arch::OpClassTensorOp, arch::Sm75, + int32_t, arch::OpClassTensorOp, arch::Sm75, ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, InstructionShape, EpilogueOutputOp0, EpilogueOutputOp1, - ThreadblockSwizzle, 2, SplitKSerial, Operator, true> { + ThreadblockSwizzle, 2, Operator, true> { using LayoutA = layout::ColumnMajorInterleaved; using LayoutB = layout::RowMajorInterleaved; using LayoutC = layout::ColumnMajorInterleaved; @@ -371,9 +358,9 @@ struct DefaultB2bGemm, /// Define the threadblock-scoped matrix multiply-accumulate using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma< - ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, - ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75, - ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB, kAlignmentB, + ElementAccumulator, LayoutC, arch::OpClassTensorOp, arch::Sm75, + ThreadblockShape0, ThreadblockShape1, WarpShape0, WarpShape1, InstructionShape, 2, Operator, EpilogueOutputOp0, true, true>::ThreadblockB2bMma; static const int kPartitionsK1 = ThreadblockShape1::kK / WarpShape1::kK; @@ -385,7 +372,7 @@ struct DefaultB2bGemm, 64 / sizeof_bits::value, InterleavedK>::Epilogue; /// Define the kernel-level GEMM operator. - using B2bGemmKernel = kernel::B2bGemm; + using B2bGemmKernel = kernel::B2bGemm; }; //////////////////////////////////////////////////////////////////////////////// diff --git a/examples/13_two_tensor_op_fusion/kernel/grouped.h b/examples/13_two_tensor_op_fusion/kernel/grouped.h new file mode 100644 index 0000000000..2698a281a1 --- /dev/null +++ b/examples/13_two_tensor_op_fusion/kernel/grouped.h @@ -0,0 +1,168 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief High-level interface for running a grouped version of a CUTLASS kernel +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/trace.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// High-level interface for running a grouped version of a CUTLASS kernel +template < + typename BaseKernel_ ///! 
Kernel-scoped matrix multiply-accumulate +> +struct GroupedKernel { +public: + + using BaseKernel = BaseKernel_; + using Epilogue = typename BaseKernel::Epilogue; + + /// Types that need to be exported to work properly with device::BaseGrouped + using ElementA = typename BaseKernel::ElementA; + using LayoutA = typename BaseKernel::LayoutA; + using TensorRefA = TensorRef; + static ComplexTransform const kTransformA = BaseKernel::kTransformA; + static int const kAlignmentA = BaseKernel::kAlignmentA; + + using ElementB = typename BaseKernel::ElementB; + using LayoutB = typename BaseKernel::LayoutB; + using TensorRefB = TensorRef; + static ComplexTransform const kTransformB = BaseKernel::kTransformB; + static int const kAlignmentB = BaseKernel::kAlignmentB; + + using ElementC = typename BaseKernel::ElementC; + using LayoutC = typename BaseKernel::LayoutC; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + static int const kAlignmentC = BaseKernel::kAlignmentC; + + using ElementAccumulator = typename BaseKernel::Mma::Policy::Operator::ElementC; + + using EpilogueOutputOp = typename BaseKernel::EpilogueOutputOp; + using ThreadblockSwizzle = typename BaseKernel::ThreadblockSwizzle; + + using Operator = typename BaseKernel::Operator; + using WarpMmaOperator = typename BaseKernel::Mma::Policy::Operator; + + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename WarpMmaOperator::MathOperator; + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + using ThreadblockShape = typename BaseKernel::Mma::Shape; + using WarpShape = typename BaseKernel::WarpShape; + using InstructionShape = typename BaseKernel::InstructionShape; + static int const kStages = BaseKernel::Mma::kStages; + + using Mma = typename BaseKernel::Mma; + + using Arguments = typename BaseKernel::GroupedArguments; + using Params = typename BaseKernel::GroupedParams; + using ProblemVisitor = typename ThreadblockSwizzle::ProblemVisitor; + + static int const kThreadCount = BaseKernel::kThreadCount; + + /// Shared memory storage structure + struct SharedStorage { + typename BaseKernel::SharedStorage kernel; + + // ProblemVisitor shared storage can't be overlapped with others + typename ProblemVisitor::SharedStorage problem_visitor; + }; + +public: + + // + // Methods + // + + CUTLASS_DEVICE + GroupedKernel() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) { + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return Status::kSuccess; + } + + /// Executes a kernel-level GEMM in a loop + CUTLASS_DEVICE + void operator()(Params ¶ms, SharedStorage &shared_storage) { + + ThreadblockSwizzle swizzle(params.problem_visitor, shared_storage.problem_visitor, blockIdx.x); + + if (ProblemVisitor::kTransposed) { + params.transpose(); + } + + BaseKernel mma; + + // Outer 'persistent' loop to iterate over tiles + while (swizzle.problem_visitor.next_tile()) { + + typename BaseKernel::Params mma_params = params.to_single_params(swizzle.problem_visitor); + mma.run_with_swizzle(mma_params, shared_storage.kernel, swizzle); + + // Next tile + swizzle.problem_visitor.advance(gridDim.x); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + 
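`GroupedKernel::operator()` above is a persistent-threadblock scheduler: a fixed grid is launched once, and every threadblock keeps pulling tiles from the shared problem visitor, converting the grouped parameters into per-tile `BaseKernel::Params` via `to_single_params` and invoking the underlying B2b kernel until no tiles remain. This avoids launching one block per tile and lets the same grid load-balance across problems of different sizes. A minimal CUDA sketch of that loop structure, with hypothetical stand-ins (`TileWork`, `work[]`, `process_tile`) in place of `ProblemVisitor` and the fused B2b kernel, and with the tile list assumed to be precomputed rather than derived on the fly as the visitor may do:

#include <cuda_runtime.h>

struct TileWork {
  int problem_idx;   // which GEMM of the group this tile belongs to
  int tile_idx;      // which threadblock tile within that GEMM
};

__device__ void process_tile(TileWork tile) {
  // A real kernel would run the fused B2b mainloop and epilogue for this tile.
}

__global__ void grouped_persistent_kernel(TileWork const *work, int total_tiles) {
  // Each threadblock starts at its own block index and strides by the grid
  // size, mirroring swizzle.problem_visitor.advance(gridDim.x) above.
  for (int idx = blockIdx.x; idx < total_tiles; idx += gridDim.x) {
    process_tile(work[idx]);
  }
}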
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/13_two_tensor_op_fusion/reference/device/tensor_scale_bias.h b/examples/13_two_tensor_op_fusion/reference/device/tensor_scale_bias.h index cc33731dd9..e1ba6c563d 100644 --- a/examples/13_two_tensor_op_fusion/reference/device/tensor_scale_bias.h +++ b/examples/13_two_tensor_op_fusion/reference/device/tensor_scale_bias.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -69,7 +69,7 @@ __global__ void TensorScaleBiasGemm( TensorRefScalar tensor_scale, ///< scale tensor TensorRefScalar tensor_bias ///< bias tensor ) { - + ConvertOp convert_op; MatrixCoord output_coord( @@ -89,7 +89,7 @@ __global__ void TensorScaleBiasGemm( ScalarType bias = ScalarType(0); - if(tensor_bias.good()) + if(tensor_bias.good()) bias = tensor_bias.at({0, coord.column()}); tensor_out.at(coord) = convert_op( @@ -99,6 +99,70 @@ __global__ void TensorScaleBiasGemm( } } +template < + typename TensorRefIn, ///< Input TensorRef Type + typename TensorRefOut, ///< Output TensorRef Type + typename ScalarType, ///< alpha Type + typename TensorRefScalar, ///< Scale/Bias TensorRef Type + typename ConvertOp = NumericConverter, + int kMblock = 4, + int kNblock = 4 +> +__global__ void TensorScaleBiasGemmBatched( + gemm::GemmCoord problem_size, + TensorRefIn tensor_in, ///< input tensor + TensorRefOut tensor_out, ///< output tensor + ScalarType alpha, ///< alpha + TensorRefScalar tensor_scale, ///< scale tensor + TensorRefScalar tensor_bias, ///< bias tensor + int batch_count = 1, + int64_t batch_stride_tensor_in = 0, + int64_t batch_stride_tensor_out = 0, + int64_t batch_stride_tensor_scale = 0, + int64_t batch_stride_tensor_bias = 0 +) { + + ConvertOp convert_op; + int row_block = (blockIdx.x * blockDim.x + threadIdx.x) * kMblock; + int col_block = (blockIdx.y * blockDim.y + threadIdx.y) * kNblock; + int batch_idx = blockIdx.z; + + tensor_in.add_pointer_offset(batch_idx * batch_stride_tensor_in); + tensor_out.add_pointer_offset(batch_idx * batch_stride_tensor_out); + tensor_scale.add_pointer_offset(batch_idx * batch_stride_tensor_scale); + tensor_bias.add_pointer_offset(batch_idx * batch_stride_tensor_bias); + + for (; batch_idx < batch_count; batch_idx += gridDim.z) { + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kNblock; j++) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kMblock; i++) { + int row = row_block + i; + int col = col_block + j; + MatrixCoord coord = MatrixCoord(row, col); + if (coord.row() < problem_size.m() && coord.column() < problem_size.n()) { + + ScalarType scale = alpha; + if(tensor_scale.good()) + scale = tensor_scale.at({0, coord.column()}); + + ScalarType bias = ScalarType(0); + + if(tensor_bias.good()) + bias = tensor_bias.at({0, coord.column()}); + + tensor_out.at(coord) = convert_op( + scale * ScalarType(tensor_in.at(coord)) + bias); + } + } + } + tensor_in.add_pointer_offset(batch_stride_tensor_in * gridDim.z); + tensor_out.add_pointer_offset(batch_stride_tensor_out * gridDim.z); + tensor_scale.add_pointer_offset(batch_stride_tensor_scale * gridDim.z); + tensor_bias.add_pointer_offset(batch_stride_tensor_bias * gridDim.z); + } +} + template < 
typename TensorRefIn, ///< Input TensorRef Type typename TensorRefOut, ///< Output TensorRef Type @@ -118,7 +182,7 @@ __global__ void TensorScaleBiasConv2d( TensorRefScalar tensor_scale, ///< scale tensor TensorRefScalar tensor_bias ///< bias tensor ) { - + ConvertOp convert_op; int64_t npq_start = int64_t(blockIdx.x) * kCtaShapeM * kThreadM + threadIdx.x * kThreadM; @@ -137,7 +201,7 @@ __global__ void TensorScaleBiasConv2d( int64_t npq = npq_start + m; thread_n[m] = int(npq / PQ); - + int64_t residual = npq % PQ; thread_p[m] = int(residual / problem_size.Q); thread_q[m] = int(residual % problem_size.Q); @@ -155,17 +219,17 @@ __global__ void TensorScaleBiasConv2d( ScalarType scale = alpha; if(tensor_scale.good()) scale = tensor_scale.at({0, thread_k}); - + ScalarType bias = ScalarType(0); - if(tensor_bias.good()) + if(tensor_bias.good()) bias = tensor_bias.at({0, thread_k}); - + tensor_out.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) = convert_op( scale * ScalarType( tensor_in.at({thread_n[m], thread_p[m], thread_q[m], thread_k}) ) + bias); } - } + } } } @@ -217,6 +281,62 @@ void TensorScaleBiasGemm( ); } +/// Apply scale and bias on a tensor +template < + typename ElementIn, ///< Input Type + typename ElementOut, ///< Output Type + typename Layout, ///< Layout of input/output tensor + typename ScalarType, ///< alpha Type + typename LayoutScaleBias, ///< Layout of scale and bias + typename ConvertOp = NumericConverter +> +void TensorScaleBiasGemmBatched( + gemm::GemmCoord problem_size, + TensorRef tensor_in, ///< input tensor + TensorRef tensor_out, ///< output tensor + ScalarType alpha, ///< alpha + TensorRef tensor_scale, ///< scale tensor + TensorRef tensor_bias, ///< bias tensor + int batch_count = 1, + int64_t batch_stride_tensor_in = 0, + int64_t batch_stride_tensor_out = 0, + int64_t batch_stride_tensor_scale = 0, + int64_t batch_stride_tensor_bias = 0 +) { + + int const kMblock = 4; + int const kNblock = 4; + + dim3 block(16, 8); + dim3 grid( + (problem_size.m() + block.x * kMblock - 1) / (block.x * kMblock), + (problem_size.n() + block.y * kNblock - 1) / (block.y * kNblock), + batch_count % std::numeric_limits::max() + ); + + kernel::TensorScaleBiasGemmBatched< + TensorRef, + TensorRef, + ScalarType, + TensorRef, + ConvertOp, + kMblock, + kNblock + ><<< grid, block >>> ( + problem_size, + tensor_in, + tensor_out, + alpha, + tensor_scale, + tensor_bias, + batch_count, + batch_stride_tensor_in, + batch_stride_tensor_out, + batch_stride_tensor_scale, + batch_stride_tensor_bias + ); +} + /// Apply scale and bias on a tensor template < typename ElementIn, ///< Input Type diff --git a/examples/13_two_tensor_op_fusion/test_run.h b/examples/13_two_tensor_op_fusion/test_run.h index b14becafc9..2bd6c720a4 100644 --- a/examples/13_two_tensor_op_fusion/test_run.h +++ b/examples/13_two_tensor_op_fusion/test_run.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage.h b/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage.h index 6229b59506..574b123dc7 100644 --- a/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage.h +++ b/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage_smem_accumulator.h b/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage_smem_accumulator.h index ab08c0051d..e7c7ad12dc 100644 --- a/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage_smem_accumulator.h +++ b/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_multistage_smem_accumulator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined.h b/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined.h index 6d0f6db4c6..8313cef8e6 100644 --- a/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined.h +++ b/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -321,7 +321,7 @@ class B2bImplicitGemmPipelined : int smem_write_stage_idx = 1; // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing - // shared memory loads (which have the tighest latency requirement). + // shared memory loads (which have the tightest latency requirement). // // Mainloop @@ -461,7 +461,7 @@ class B2bImplicitGemmPipelined : int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1; // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing - // shared memory loads (which have the tighest latency requirement). + // shared memory loads (which have the tightest latency requirement). 
// // Mainloop diff --git a/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined_smem_accumulator.h b/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined_smem_accumulator.h index 04e4f29f43..9775c19e26 100644 --- a/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined_smem_accumulator.h +++ b/examples/13_two_tensor_op_fusion/threadblock/b2b_implicit_gemm_pipelined_smem_accumulator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -341,7 +341,7 @@ class B2bImplicitGemmPipelinedSmemAccumulator : int smem_write_stage_idx = 1; // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing - // shared memory loads (which have the tighest latency requirement). + // shared memory loads (which have the tightest latency requirement). // // Mainloop diff --git a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base.h b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base.h index 36d6d8c9f3..55a41be19c 100644 --- a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base.h +++ b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h index 22643570b2..2d5f616323 100644 --- a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h +++ b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h index 8104f63854..3fb684ce57 100644 --- a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h +++ b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -119,8 +119,10 @@ class B2bMmaMultistage : using Shape0 = Shape0_; ///< Iterates over tiles of A operand in global memory using IteratorA0 = IteratorA0_; + using IteratorA = IteratorA0; ///< Iterates over tiles of B operand in global memory using IteratorB0 = IteratorB0_; + using IteratorB = IteratorB0; ///< Policy describing tuning details using Policy0 = Policy0_; @@ -139,6 +141,10 @@ class B2bMmaMultistage : using IteratorB1 = IteratorB1_; ///< Policy describing tuning details using Policy1 = Policy1_; + + ///< Export Policy0 as the threadblock-level Mma's policy + using Policy = Policy0; + using Shape = Shape0; using SmemIteratorB1 = SmemIteratorB1_; @@ -188,6 +194,10 @@ class B2bMmaMultistage : /// Complex transform on B operand static ComplexTransform const kTransformB1 = Operator1::kTransformB; + /// Complex transform exports needed by higher-level kernels + static ComplexTransform const kTransformA = kTransformA0; + static ComplexTransform const kTransformB = kTransformB0; + /// Internal structure exposed for introspection. struct Detail { @@ -199,15 +209,15 @@ class B2bMmaMultistage : "GEMM operations."); /// Number of cp.async instructions to load one stage of operand A - static int const TBLDGSTSIterationsA0 = + static int const TBLoadIterationsA0 = IteratorA0::ThreadMap::Iterations::kCount; /// Number of cp.async instructions to load one stage of operand B - static int const TBLDGSTSIterationsB0 = + static int const TBLoadIterationsB0 = IteratorB0::ThreadMap::Iterations::kCount; /// Number of cp.async instructions to load one stage of operand B - static int const TBLDGSTSIterationsB1 = + static int const TBLoadIterationsB1 = IteratorB1::ThreadMap::Iterations::kCount; /// Number of stages @@ -215,15 +225,15 @@ class B2bMmaMultistage : /// Number of cp.async instructions to load on group of operand A static int const kAccessesPerGroupA0 = - (TBLDGSTSIterationsA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; + (TBLoadIterationsA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; /// Number of cp.async instructions to load on group of operand B static int const kAccessesPerGroupB0 = - (TBLDGSTSIterationsB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; + (TBLoadIterationsB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; /// Number of cp.async instructions to load on group of operand B static int const kAccessesPerGroupB1 = - (TBLDGSTSIterationsB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1; + (TBLoadIterationsB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1; }; private: @@ -267,7 +277,9 @@ class B2bMmaMultistage : ///< ID of warp int warp_idx, ///< ID of each thread within a warp - int lane_idx + int lane_idx, + ///< GEMM0 N is used for accumulator extent + int problem_size_0_n ): Base(shared_storage, thread_idx, warp_idx, lane_idx), smem_iterator_A0_(shared_storage.shared_storage0.operand_A_ref(), thread_idx), @@ -302,10 +314,10 @@ class B2bMmaMultistage : IteratorA0::kAccessesPerVector); this->smem_iterator_A0_.set_iteration_index(group_start_A0); - // LDGSTS for operand A + // Load for operand A CUTLASS_PRAGMA_UNROLL for (int j = 0; j < Detail::kAccessesPerGroupA0; ++j) { - if (group_start_A0 + j < Detail::TBLDGSTSIterationsA0) { + if (group_start_A0 + j < Detail::TBLoadIterationsA0) { typename IteratorA0::AccessType *dst_ptr = reinterpret_cast( 
this->smem_iterator_A0_.get()); @@ -332,10 +344,10 @@ class B2bMmaMultistage : IteratorB0::kAccessesPerVector); this->smem_iterator_B0_.set_iteration_index(group_start_B0); - // LDGSTS for operand B + // Load for operand B CUTLASS_PRAGMA_UNROLL for (int j = 0; j < Detail::kAccessesPerGroupB0; ++j) { - if (group_start_B0 + j < Detail::TBLDGSTSIterationsB0) { + if (group_start_B0 + j < Detail::TBLoadIterationsB0) { typename IteratorB0::AccessType *dst_ptr = reinterpret_cast( this->smem_iterator_B0_.get()); @@ -365,10 +377,10 @@ class B2bMmaMultistage : IteratorB1::kAccessesPerVector); this->smem_iterator_B1_.set_iteration_index(group_start_B1); - // LDGSTS for operand B + // Load for operand B CUTLASS_PRAGMA_UNROLL for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { - if (group_start_B1 + j < Detail::TBLDGSTSIterationsB1) { + if (group_start_B1 + j < Detail::TBLoadIterationsB1) { typename IteratorB1::AccessType *dst_ptr = reinterpret_cast( this->smem_iterator_B1_.get()); @@ -428,9 +440,9 @@ class B2bMmaMultistage : iterator_A0.set_iteration_index(0); this->smem_iterator_A0_.set_iteration_index(0); - // LDGSTS for operand A + // Load for operand A CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::TBLDGSTSIterationsA0; ++j) { + for (int j = 0; j < Detail::TBLoadIterationsA0; ++j) { typename IteratorA0::AccessType *dst_ptr = reinterpret_cast( this->smem_iterator_A0_.get()); @@ -456,9 +468,9 @@ class B2bMmaMultistage : iterator_B0.set_iteration_index(0); this->smem_iterator_B0_.set_iteration_index(0); - // LDGSTS for operand B + // Load for operand B CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::TBLDGSTSIterationsB0; ++j) { + for (int j = 0; j < Detail::TBLoadIterationsB0; ++j) { typename IteratorB0::AccessType *dst_ptr = reinterpret_cast( this->smem_iterator_B0_.get()); @@ -639,6 +651,10 @@ class B2bMmaMultistage : } + // Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); // 2nd Gemm @@ -657,12 +673,11 @@ class B2bMmaMultistage : tb_frag_A1_bias.clear(); iterator_A1_bias.load(tb_frag_A1_bias); ++iterator_A1_bias; - - + // // Prologue // - int gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1; + int gemm_k_iterations_1 = (FragmentIteratorA1::Policy::kIterations + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1; // Issue several complete stages CUTLASS_PRAGMA_UNROLL @@ -674,9 +689,9 @@ class B2bMmaMultistage : iterator_B1.set_iteration_index(0); this->smem_iterator_B1_.set_iteration_index(0); - // LDGSTS for operand B + // Load for operand B CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::TBLDGSTSIterationsB1; ++j) { + for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) { typename IteratorB1::AccessType *dst_ptr = reinterpret_cast( this->smem_iterator_B1_.get()); @@ -750,9 +765,9 @@ class B2bMmaMultistage : // Mainloop // + gemm_k_iterations_1 = (FragmentIteratorA1::Policy::kIterations + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1 - (Base::kStages - 1); CUTLASS_PRAGMA_UNROLL - for (gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1 - (Base::kStages - 1); - gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) { + for (; gemm_k_iterations_1 > (-Base::kStages + 1); gemm_k_iterations_1--) { // // Loop over GEMM K dimension // @@ -871,7 +886,10 @@ class B2bMmaMultistage : } - + // Commit and drain all pending and predicated cp.async pnz from 
the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); } }; diff --git a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h index c28f4e49cd..35c4f5cc3f 100644 --- a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h +++ b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -121,8 +121,10 @@ class B2bMmaMultistageSmemAccumulator : using Shape0 = Shape0_; ///< Iterates over tiles of A operand in global memory using IteratorA0 = IteratorA0_; + using IteratorA = IteratorA0; ///< Iterates over tiles of B operand in global memory using IteratorB0 = IteratorB0_; + using IteratorB = IteratorB0; ///< Iterates over tiles of the scale and bias vectors in global memory using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Policy describing tuning details @@ -141,6 +143,10 @@ class B2bMmaMultistageSmemAccumulator : ///< Policy describing tuning details using Policy1 = Policy1_; + ///< Export Policy0 as the threadblock-level Mma's policy + using Policy = Policy0; + using Shape = Shape0; + using SmemIteratorB1 = SmemIteratorB1_; using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory @@ -194,6 +200,10 @@ class B2bMmaMultistageSmemAccumulator : /// Complex transform on B operand static ComplexTransform const kTransformB1 = Operator1::kTransformB; + /// Complex transform exports needed by higher-level kernels + static ComplexTransform const kTransformA = kTransformA0; + static ComplexTransform const kTransformB = kTransformB0; + /// Internal structure exposed for introspection. 
struct Detail { @@ -205,15 +215,15 @@ class B2bMmaMultistageSmemAccumulator : "GEMM operations."); /// Number of cp.async instructions to load one stage of operand A - static int const TBLDGSTSIterationsA0 = + static int const TBLoadIterationsA0 = IteratorA0::ThreadMap::Iterations::kCount; /// Number of cp.async instructions to load one stage of operand B - static int const TBLDGSTSIterationsB0 = + static int const TBLoadIterationsB0 = IteratorB0::ThreadMap::Iterations::kCount; /// Number of cp.async instructions to load one stage of operand B - static int const TBLDGSTSIterationsB1 = + static int const TBLoadIterationsB1 = IteratorB1::ThreadMap::Iterations::kCount; /// Number of stages @@ -221,15 +231,15 @@ class B2bMmaMultistageSmemAccumulator : /// Number of cp.async instructions to load on group of operand A static int const kAccessesPerGroupA0 = - (TBLDGSTSIterationsA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; + (TBLoadIterationsA0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; /// Number of cp.async instructions to load on group of operand B static int const kAccessesPerGroupB0 = - (TBLDGSTSIterationsB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; + (TBLoadIterationsB0 + Base::kWarpGemmIterations0 - 1) / Base::kWarpGemmIterations0; /// Number of cp.async instructions to load on group of operand B static int const kAccessesPerGroupB1 = - (TBLDGSTSIterationsB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1; + (TBLoadIterationsB1 + Base::kWarpGemmIterations1 - 1) / Base::kWarpGemmIterations1; }; private: @@ -276,13 +286,15 @@ class B2bMmaMultistageSmemAccumulator : ///< ID of warp int warp_idx, ///< ID of each thread within a warp - int lane_idx + int lane_idx, + ///< GEMM0 N is used for accumulator extent + int problem_size_0_n ): Base(shared_storage, thread_idx, warp_idx, lane_idx), smem_iterator_A0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx), smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx), smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx), - warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx), + warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), {Base::WarpGemm1::kM, problem_size_0_n}, lane_idx ), smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx) { // Compute warp location within threadblock tile by mapping the warp_id to @@ -325,10 +337,10 @@ class B2bMmaMultistageSmemAccumulator : IteratorA0::kAccessesPerVector); this->smem_iterator_A0_.set_iteration_index(group_start_A0); - // LDGSTS for operand A + // cp.async for operand A CUTLASS_PRAGMA_UNROLL for (int j = 0; j < Detail::kAccessesPerGroupA0; ++j) { - if (group_start_A0 + j < Detail::TBLDGSTSIterationsA0) { + if (group_start_A0 + j < Detail::TBLoadIterationsA0) { typename IteratorA0::AccessType *dst_ptr = reinterpret_cast( this->smem_iterator_A0_.get()); @@ -355,10 +367,10 @@ class B2bMmaMultistageSmemAccumulator : IteratorB0::kAccessesPerVector); this->smem_iterator_B0_.set_iteration_index(group_start_B0); - // LDGSTS for operand B + // cp.async for operand B CUTLASS_PRAGMA_UNROLL for (int j = 0; j < Detail::kAccessesPerGroupB0; ++j) { - if (group_start_B0 + j < Detail::TBLDGSTSIterationsB0) { + if (group_start_B0 + j < Detail::TBLoadIterationsB0) { typename IteratorB0::AccessType *dst_ptr = reinterpret_cast( 
this->smem_iterator_B0_.get()); @@ -388,10 +400,10 @@ class B2bMmaMultistageSmemAccumulator : IteratorB1::kAccessesPerVector); this->smem_iterator_B1_.set_iteration_index(group_start_B1); - // LDGSTS for operand B + // cp.async for operand B CUTLASS_PRAGMA_UNROLL for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { - if (group_start_B1 + j < Detail::TBLDGSTSIterationsB1) { + if (group_start_B1 + j < Detail::TBLoadIterationsB1) { typename IteratorB1::AccessType *dst_ptr = reinterpret_cast( this->smem_iterator_B1_.get()); @@ -451,9 +463,9 @@ class B2bMmaMultistageSmemAccumulator : iterator_A0.set_iteration_index(0); this->smem_iterator_A0_.set_iteration_index(0); - // LDGSTS for operand A + // cp.async for operand A CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::TBLDGSTSIterationsA0; ++j) { + for (int j = 0; j < Detail::TBLoadIterationsA0; ++j) { typename IteratorA0::AccessType *dst_ptr = reinterpret_cast( this->smem_iterator_A0_.get()); @@ -479,9 +491,9 @@ class B2bMmaMultistageSmemAccumulator : iterator_B0.set_iteration_index(0); this->smem_iterator_B0_.set_iteration_index(0); - // LDGSTS for operand B + // cp.async for operand B CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::TBLDGSTSIterationsB0; ++j) { + for (int j = 0; j < Detail::TBLoadIterationsB0; ++j) { typename IteratorB0::AccessType *dst_ptr = reinterpret_cast( this->smem_iterator_B0_.get()); @@ -662,6 +674,11 @@ class B2bMmaMultistageSmemAccumulator : } + // Insert fence and wait for all outstanding cp.async operations to commit. + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + /// Epilogue for the first Implicit Gemm Epilogue0 epilogue0; @@ -687,9 +704,9 @@ class B2bMmaMultistageSmemAccumulator : iterator_B1.set_iteration_index(0); this->smem_iterator_B1_.set_iteration_index(0); - // LDGSTS for operand B + // cp.async for operand B CUTLASS_PRAGMA_UNROLL - for (int j = 0; j < Detail::TBLDGSTSIterationsB1; ++j) { + for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) { typename IteratorB1::AccessType *dst_ptr = reinterpret_cast( this->smem_iterator_B1_.get()); @@ -853,7 +870,10 @@ class B2bMmaMultistageSmemAccumulator : } - + // Commit and drain all pending and predicated cp.async pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); } }; diff --git a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h index 4e39fda5b6..d5f1629466 100644 --- a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h +++ b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -126,7 +126,9 @@ class B2bMmaPipelined : using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA0; using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB0; using Policy0 = Policy0_; ///< Policy describing tuning details using SmemIteratorA0 = SmemIteratorA0_; @@ -139,6 +141,8 @@ class B2bMmaPipelined : FragmentIteratorA1ScaleBias_; ///< WarpIterator to load Scale or Bias vector from the threadblock fragment using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory using Policy1 = Policy1_; ///< Policy describing tuning details + using Policy = Policy1; ///< Export Policy1 as the threadblock-level Mma's policy + using Shape = Shape1; using SmemIteratorB1 = SmemIteratorB1_; @@ -195,6 +199,10 @@ class B2bMmaPipelined : /// Complex transform on B1 operand static ComplexTransform const kTransformB1 = Operator1::kTransformB; + /// Complex transform exports needed by higher-level kernels + static ComplexTransform const kTransformA = kTransformA0; + static ComplexTransform const kTransformB = kTransformB0; + /// staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); @@ -228,7 +236,8 @@ class B2bMmaPipelined : typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM int thread_idx, ///< ID within the threadblock int warp_idx, ///< ID of warp - int lane_idx ///< ID of each thread within a warp + int lane_idx, ///< ID of each thread within a warp + int problem_size_0_n ///< GEMM0 N is used for accumulator extent ): Base(shared_storage, thread_idx, warp_idx, lane_idx), smem_iterator_A_(shared_storage.shared_storage0.operand_A_ref(), thread_idx), @@ -324,7 +333,7 @@ class B2bMmaPipelined : iterator_B0.clear_mask(gemm_k_iterations_0 <= 1); // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing - // shared memory loads (which have the tighest latency requirement). + // shared memory loads (which have the tightest latency requirement). // // Mainloop diff --git a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h index b548c85763..c3393e0ccf 100644 --- a/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h +++ b/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -128,7 +128,9 @@ class B2bMmaPipelinedSmemAccumulator : using Shape0 = Shape0_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> using IteratorA0 = IteratorA0_; ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA0; using IteratorB0 = IteratorB0_; ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB0; using IteratorAccumulatorScaleBias = IteratorAccumulatorScaleBias_; ///< Iterates over tiles of the scale and bias vectors in global memory using Policy0 = Policy0_; ///< Policy0 describing tuning details @@ -141,6 +143,8 @@ class B2bMmaPipelinedSmemAccumulator : using Shape1 = Shape1_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> using IteratorB1 = IteratorB1_; ///< Iterates over tiles of B operand in global memory using Policy1 = Policy1_; ///< Policy1 describing tuning details + using Policy = Policy1; ///< Export Policy1 as the threadblock-level Mma's policy + using Shape = Shape1; using SmemIteratorB1 = SmemIteratorB1_; using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate accumulator tile in shared memory @@ -192,6 +196,10 @@ class B2bMmaPipelinedSmemAccumulator : /// Complex transform on B1 operand static ComplexTransform const kTransformB1 = Operator1::kTransformB; + /// Complex transform exports needed by higher-level kernels + static ComplexTransform const kTransformA = kTransformA0; + static ComplexTransform const kTransformB = kTransformB0; + /// staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); @@ -236,13 +244,14 @@ class B2bMmaPipelinedSmemAccumulator : typename Base::B2bMmaSharedStorage &shared_storage, ///< Shared storage needed for internal use by threadblock-scoped GEMM int thread_idx, ///< ID within the threadblock int warp_idx, ///< ID of warp - int lane_idx ///< ID of each thread within a warp + int lane_idx, ///< ID of each thread within a warp + int problem_size_0_n ///< GEMM0 N is used for accumulator extent ): Base(shared_storage, thread_idx, warp_idx, lane_idx), smem_iterator_A_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_A_ref(), thread_idx), smem_iterator_B0_(shared_storage.b2b_mma_shared_storage.shared_storage0.operand_B_ref(), thread_idx), smem_iterator_D0_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx), - warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), lane_idx), + warp_tile_iterator_A1_(shared_storage.accumulator_shared_storage0.accum_ref(), {Base::WarpGemm1::kM, problem_size_0_n}, lane_idx), smem_iterator_B1_(shared_storage.b2b_mma_shared_storage.shared_storage1.operand_B_ref(), thread_idx) { // Compute warp location within threadblock tile by mapping the warp_id to @@ -345,7 +354,7 @@ class B2bMmaPipelinedSmemAccumulator : iterator_B0.clear_mask(gemm_k_iterations_0 <= 1); // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing - // shared memory loads (which have the tighest latency requirement). + // shared memory loads (which have the tightest latency requirement). 
// // Mainloop diff --git a/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma.h b/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma.h index 3c12e05cc0..2ea38cebfa 100644 --- a/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma.h +++ b/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma_smem_accumulator.h b/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma_smem_accumulator.h index ea1a258fbd..7a97ce0312 100644 --- a/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma_smem_accumulator.h +++ b/examples/13_two_tensor_op_fusion/threadblock/default_b2b_mma_smem_accumulator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -43,7 +43,7 @@ #include "cutlass/gemm/threadblock/default_mma_core_sm70.h" #include "cutlass/gemm/threadblock/default_mma_core_sm75.h" #include "cutlass/gemm/threadblock/default_mma_core_sm80.h" -#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" #include "threadblock/b2b_mma_pipelined_smem_accumulator.h" #include "threadblock/b2b_mma_multistage_smem_accumulator.h" @@ -158,11 +158,11 @@ struct DefaultB2bMma, cutlass::gemm::Operand::kA, + using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< + MatrixShape, cutlass::gemm::Operand::kA, ElementA, SmemAccumulatorLayout, MatrixShape, - WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; + WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true>; // Define the threadblock-scoped pipelined matrix multiply using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelinedSmemAccumulator< @@ -303,11 +303,11 @@ struct DefaultB2bMma, cutlass::gemm::Operand::kA, + using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< + MatrixShape, cutlass::gemm::Operand::kA, ElementA, SmemAccumulatorLayout, MatrixShape, - WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; + WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true>; // Define the threadblock-scoped pipelined matrix multiply using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaMultistageSmemAccumulator< @@ -436,11 +436,11 @@ struct DefaultB2bMma, cutlass::gemm::Operand::kA, + using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< + MatrixShape, cutlass::gemm::Operand::kA, ElementA, SmemAccumulatorLayout, MatrixShape, - WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; + WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true>; // Define the threadblock-scoped pipelined matrix multiply using ThreadblockB2bMma = cutlass::gemm::threadblock::B2bMmaPipelinedSmemAccumulator< @@ -574,11 +574,11 @@ struct DefaultB2bMma, 
cutlass::gemm::Operand::kA, + using WarpIteratorA1 = cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< + MatrixShape, cutlass::gemm::Operand::kA, ElementA, SmemAccumulatorLayout, MatrixShape, - WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount>; + WarpMmaTensorOp1::Policy::OpDelta::kRow, kThreadCount, true >; // Define the threadblock-scoped multistage matrix multiply diff --git a/examples/13_two_tensor_op_fusion/threadblock/grouped_threadblock_swizzle.h b/examples/13_two_tensor_op_fusion/threadblock/grouped_threadblock_swizzle.h new file mode 100644 index 0000000000..c79b7c77f9 --- /dev/null +++ b/examples/13_two_tensor_op_fusion/threadblock/grouped_threadblock_swizzle.h @@ -0,0 +1,125 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Implements several threadblock-swizzling functions for grouped kernels +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/kernel/grouped_problem_visitor.h" +#include "cutlass/gemm/kernel/gemm_grouped_problem_visitor.h" +#include "kernel/b2b_gemm_grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +struct GroupedThreadblockSwizzleBase {}; + +/// Helper for determining if a swizzling function is specialized for grouped operation +template +struct IsGroupedSwizzle { + static bool const value = cutlass::platform::is_base_of::value; +}; + +} // namespace detail + +/// Swizzling function for grouped kernels +template +struct GroupedThreadblockSwizzle : detail::GroupedThreadblockSwizzleBase { + + using ProblemVisitor = ProblemVisitor_; + ProblemVisitor problem_visitor; + + CUTLASS_HOST_DEVICE + GroupedThreadblockSwizzle(typename ProblemVisitor::Params& params, + typename ProblemVisitor::SharedStorage& shared_storage, + int block_idx) : problem_visitor(params, shared_storage, block_idx) {} + + /// Obtains the threadblock offset (in units of threadblock-scoped tiles) + CUTLASS_DEVICE + GemmCoord get_tile_offset(int /*log_tile*/) const { + GemmCoord problem_size = problem_visitor.problem_size(); + int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + GemmCoord grid_shape = problem_visitor.grid_shape(problem_size); + + return GemmCoord(int(threadblock_idx / grid_shape.n()), + int(threadblock_idx % grid_shape.n()), + 0); + } + + /// Dummy method to satisfy API for threadblock swizzling functions + CUTLASS_HOST_DEVICE + static int get_log_tile(GemmCoord /*tiled_shape*/) { + return 0; + } +}; + +template < + typename ThreadblockShape, + typename LayoutC, + cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode_ = cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly, + int PrefetchTileCount = 128, + int ThreadCount = PrefetchTileCount> +struct B2bGemmGroupedThreadblockSwizzle : GroupedThreadblockSwizzle< + cutlass::gemm::kernel::B2bGemmGroupedProblemVisitor< + ThreadblockShape, + GroupScheduleMode_, + PrefetchTileCount, + ThreadCount, + platform::is_same::value + > + > { + using Base = GroupedThreadblockSwizzle::value>>; + + CUTLASS_HOST_DEVICE + B2bGemmGroupedThreadblockSwizzle(typename Base::ProblemVisitor::Params& params, + typename Base::ProblemVisitor::SharedStorage& shared_storage, + int block_idx) : Base(params, shared_storage, block_idx) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/examples/14_ampere_tf32_tensorop_gemm/CMakeLists.txt b/examples/14_ampere_tf32_tensorop_gemm/CMakeLists.txt index 23b8f0dd46..3e0b870f30 100644 --- a/examples/14_ampere_tf32_tensorop_gemm/CMakeLists.txt +++ b/examples/14_ampere_tf32_tensorop_gemm/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without diff --git a/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu b/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu index eed9f9b299..99d3cdb178 100644 --- a/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu +++ b/examples/14_ampere_tf32_tensorop_gemm/ampere_tf32_tensorop_gemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -194,7 +194,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 16>; // <- warp tile M = using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8 // This code section describes how threadblocks are scheduled on GPU -using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? +using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // This code section describes the epilogue part of the kernel using EpilogueOp = cutlass::epilogue::thread::LinearCombination< diff --git a/examples/15_ampere_sparse_tensorop_gemm/CMakeLists.txt b/examples/15_ampere_sparse_tensorop_gemm/CMakeLists.txt index fbd20df8a8..02d3205889 100644 --- a/examples/15_ampere_sparse_tensorop_gemm/CMakeLists.txt +++ b/examples/15_ampere_sparse_tensorop_gemm/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -33,3 +33,12 @@ cutlass_example_add_executable( ampere_sparse_tensorop_gemm.cu ) +cutlass_example_add_executable( + 15_ampere_sparse_tensorop_gemm_universal + ampere_sparse_tensorop_gemm_universal.cu + ) + +cutlass_example_add_executable( + 15_ampere_sparse_tensorop_gemm_with_visitor + ampere_sparse_tensorop_gemm_with_visitor.cu + ) diff --git a/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu b/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu index 4a25a96341..e92b717caa 100644 --- a/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu +++ b/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -84,7 +84,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 256>; // <- warp tile M = using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 128>; // <- MMA Op tile M = 16, N = 8, K = 128 // This code section describes how threadblocks are scheduled on GPU -using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? 
+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // This code section describes the epilogue part of the kernel using EpilogueOp = cutlass::epilogue::thread::LinearCombination< diff --git a/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_universal.cu b/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_universal.cu new file mode 100644 index 0000000000..dcab5ac144 --- /dev/null +++ b/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_universal.cu @@ -0,0 +1,329 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** +Please check example 07, 08 and 17 for the basics of dense tensor op gemm kernels. NVIDIA Ampere +architecture also supports structured sparse tensor op for tf32, fp16, int8 and int4. + +Sparse GEMM kernels needs to takes an additional E matrix which stores the meta data. The format of +meta data is different for every data types. CUTLASS templates can automatically infer it based on +input A and B. Check code below. + +Moreover, matrix E needs to be preprocessed so that it can use ldmatrix to load into the registers +efficiently. 
+*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_sparse_universal.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/host_reorder.h" +#include "cutlass/util/host_uncompress.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" +#include "helper.h" + +// The code section below describes datatype for input, output matrices and computation between +// elements in input matrices. +using ElementAccumulator = int32_t; // <- data type of accumulator +using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations +using ElementInputA = cutlass::int4b_t; // <- data type of elements in input matrix A +using ElementInputB = cutlass::int4b_t; // <- data type of elements in input matrix B +using ElementOutput = int32_t; // <- data type of elements in output matrix D + +// The code section below describes matrix layout of input and output matrices. Row Major for +// Matrix A, Column Major for Matrix B and Row Major for Matrix C +using LayoutInputA = cutlass::layout::RowMajor; +using LayoutInputB = cutlass::layout::ColumnMajor; +using LayoutOutput = cutlass::layout::RowMajor; + +// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM +using MMAOp = cutlass::arch::OpClassTensorOp; + +// This code section describes CUDA SM architecture number +using SmArch = cutlass::arch::Sm80; + +// This code section describes the tile size a thread block will compute +using ShapeMMAThreadBlock = + cutlass::gemm::GemmShape<128, 128, 256>; // <- threadblock tile M = 128, N = 128, K = 256 +// This code section describes tile size a warp will compute +using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 256>; // <- warp tile M = 64, N = 64, K = 256 +// This code section describes the size of MMA op +using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 128>; // <- MMA Op tile M = 16, N = 8, K = 128 + +// This code section describes how threadblocks are scheduled on GPU +using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + +// This code section describes the epilogue part of the kernel +using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // <- data type of output matrix + 128 / cutlass::sizeof_bits::value, // <- the number of elements per vectorized + // memory access. For a byte, it's 16 + // elements. This becomes the vector width of + // math instructions in the epilogue too + ElementAccumulator, // <- data type of accumulator + ElementComputeEpilogue>; // <- data type for alpha/beta in linear combination function + +// Number of pipelines you want to use +constexpr int NumStages = 3; + +using Gemm = cutlass::gemm::device::GemmSparseUniversal; + +// Data type and layout of meta data matrix E can be inferred from template Gemm. 
+using ElementInputE = typename Gemm::ElementE; +using LayoutInputE = cutlass::layout::RowMajor; +using ReorderedLayoutInputE = typename Gemm::LayoutE; + +// Blow property is defined in include/cutlass/arch/sp_mma_sm80.h +// 50% Sparsity on Ampere +constexpr int kSparse = Gemm::kSparse; +// How many elements of A are covered per ElementE +constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; +// The size of individual meta data +constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; + +int run() { + + const int length_m = 512; + const int length_n = 512; + const int length_k = 1024; + + // Create a tuple of problem size for matrix multiplication + cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k); + + // Initialize tensors using CUTLASS helper functions + cutlass::HostTensor tensor_a( + cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse)); // <- Create matrix A with dimensions M x (K / 2) + cutlass::HostTensor tensor_a_uncompressed( + problem_size.mk()); // <- Create uncompressed matrix A with dimensions M x K for reference computing + + cutlass::HostTensor tensor_b( + problem_size.kn()); // <- Create matrix B with dimensions K x N + cutlass::HostTensor tensor_c( + problem_size.mn()); // <- Create matrix C with dimensions M x N + cutlass::HostTensor tensor_d( + problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from + // CUTLASS kernel + cutlass::HostTensor tensor_ref_d( + problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from + // reference kernel + + // Create matrix E with dimensions M x (K / 2 / kElementsPerElementE). This one is used by reference computing. + cutlass::HostTensor tensor_e( + cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); + // Same size as the above. The above one needs to be reordered and stored in this one. + cutlass::HostTensor tensor_e_reordered( + cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); + + // Fill input and output matrices on host using CUTLASS helper functions + cutlass::reference::host::TensorFillRandomUniform( + tensor_a.host_view(), + 1, + ElementInputA(2), + ElementInputA(-2), + 0); // <- Fill matrix A on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_b.host_view(), + 1, + ElementInputB(2), + ElementInputB(-2), + 0); // <- Fill matrix B on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_c.host_view(), + 1, + ElementOutput(2), + ElementOutput(-2), + 0); // <- Fill matrix C on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomSparseMeta( + tensor_e.host_view(), + 1, + kMetaSizeInBits); // <- Fill matrix E on host with uniform-distribution random meta data + cutlass::reference::host::TensorFill( + tensor_d.host_view()); // <- fill matrix D on host with zeros + cutlass::reference::host::TensorFill( + tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros + + // Reorder the meta data matrix so that we can use ldmatrix to load them to tensor core + // instructions. 
+ cutlass::reorder_meta(tensor_e_reordered.host_ref(), tensor_e.host_ref(), + {problem_size.m(), problem_size.n(), + problem_size.k() / kSparse / kElementsPerElementE}); + + // Copy data from host to GPU + tensor_a.sync_device(); + tensor_b.sync_device(); + tensor_c.sync_device(); + tensor_d.sync_device(); + tensor_e_reordered.sync_device(); + tensor_ref_d.sync_device(); + + // Initialize alpha and beta for dot product computation + ElementComputeEpilogue alpha = ElementComputeEpilogue(1); + ElementComputeEpilogue beta = ElementComputeEpilogue(0); + + // Split K dimension into 1 partitions + int split_k_slices = 2; + + // Create a tuple of gemm kernel arguments. This is later passed as arguments to launch + // instantiated CUTLASS kernel + typename Gemm::Arguments arguments{cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, // <- problem size of matrix multiplication + split_k_slices,// <- k-dimension split factor + {alpha, beta}, // <- tuple of alpha and beta + tensor_a.device_data(), // <- reference to matrix A on device + tensor_b.device_data(), // <- reference to matrix B on device + tensor_c.device_data(), // <- reference to matrix C on device + tensor_d.device_data(), // <- reference to matrix D on device + tensor_e_reordered.device_data(), // <- reference to matrix E on device + int64_t(), + int64_t(), + int64_t(), + int64_t(), + int64_t(), + tensor_a.layout().stride(0), + tensor_b.layout().stride(0), + tensor_c.layout().stride(0), + tensor_d.layout().stride(0), + tensor_e_reordered.layout().stride(0) + }; + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm_op; + + // Check the problem size is supported or not + cutlass::Status status = gemm_op.can_implement(arguments); + CUTLASS_CHECK(status); + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm_op.initialize(arguments, workspace.get()); + CUTLASS_CHECK(status); + + // Launch initialized CUTLASS kernel + status = gemm_op(); + CUTLASS_CHECK(status); + + // uncompress tensor_a based on meta data tensor_e. We need it for reference computing. + cutlass::uncompress(tensor_a_uncompressed.host_ref(), tensor_a.host_ref(), + tensor_e.host_ref(), problem_size.m(), problem_size.k()); + + // Create instantiation for host reference gemm kernel + cutlass::reference::host::Gemm + gemm_host; + + // Launch host reference gemm kernel + gemm_host(problem_size, + alpha, + tensor_a_uncompressed.host_ref(), + tensor_b.host_ref(), + beta, + tensor_c.host_ref(), + tensor_ref_d.host_ref()); + + // Copy output data from CUTLASS host for comparison + tensor_d.sync_host(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::host::TensorEquals( + tensor_d.host_view(), + tensor_ref_d.host_view()); + + std::cout << (passed ? "Passed" : "Failed") << std::endl; + + return (passed ? 0 : -1); +} + +int main() { + + bool notSupported = false; + + // Ampere Sparse Tensor Core operations exposed with mma.sync and ldmatrix are first available + // in CUDA 11.1. + // + // CUTLASS must be compiled with CUDA 11.1 Toolkit to run these examples. 
+ + if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 1))) { + std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.1 Toolkit or later." << std::endl; + notSupported = true; + } + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (props.major * 10 + props.minor < 80) { + std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." + << std::endl; + notSupported = true; + } + + if (notSupported) { + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + return run(); +} diff --git a/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu b/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu new file mode 100644 index 0000000000..90aa44528e --- /dev/null +++ b/examples/15_ampere_sparse_tensorop_gemm/ampere_sparse_tensorop_gemm_with_visitor.cu @@ -0,0 +1,377 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** +Please check example 07, 08 and 17 for the basics of dense tensor op gemm kernels. NVIDIA Ampere +architecture also supports structured sparse tensor op for tf32, fp16, int8 and int4. +Sparse GEMM kernels needs to takes an additional E matrix which stores the meta data. The format of +meta data is different for every data types. CUTLASS templates can automatically infer it based on +input A and B. Check code below. 
+Moreover, matrix E needs to be preprocessed so that it can use ldmatrix to load into the registers +efficiently. +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_sparse_with_visitor.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/host_reorder.h" +#include "cutlass/util/host_uncompress.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" + +#include "helper.h" + +// The code section below describes datatype for input, output matrices and computation between +// elements in input matrices. +using ElementAccumulator = int32_t; // <- data type of accumulator +using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations +using ElementInputA = int8_t; // <- data type of elements in input matrix A +using ElementInputB = int8_t; // <- data type of elements in input matrix B +using ElementOutput = int32_t; // <- data type of elements in output matrix D + +// The code section below describes matrix layout of input and output matrices. Row Major for +// Matrix A, Column Major for Matrix B and Row Major for Matrix C +using LayoutInputA = cutlass::layout::RowMajor; +using LayoutInputB = cutlass::layout::ColumnMajor; +using LayoutOutput = cutlass::layout::RowMajor; + +// The number of elements per vectorized memory access. +constexpr int AlignmentInputA = 128 / cutlass::sizeof_bits::value; +constexpr int AlignmentInputB = 128 / cutlass::sizeof_bits::value; +constexpr int AlignmentComputeEpilogue = 128 / cutlass::sizeof_bits::value; +constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits::value; + +// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM +using MMAOp = cutlass::arch::OpClassTensorOp; + +// This code section describes CUDA SM architecture number +using SmArch = cutlass::arch::Sm80; + +// This code section describes the tile size a thread block will compute +using ShapeMMAThreadBlock = + cutlass::gemm::GemmShape<128, 128, 128>; // <- threadblock tile M = 128, N = 128, K = 128 +// This code section describes tile size a warp will compute +using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 64, 128>; // <- warp tile M = 64, N = 64, K = 128 +// This code section describes the size of MMA op +using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 64>; // <- MMA Op tile M = 16, N = 8, K = 64 + +// This code section describes how threadblocks are scheduled on GPU +using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + +using Operator = cutlass::arch::OpMultiplyAddSaturate; + +// Number of pipelines you want to use +constexpr int NumStages = 3; + +constexpr auto NumEVTEpilogueStages = 1; + +using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; + +using BiasTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout< + ShapeMMAThreadBlock, + ShapeMMAWarp, + ElementComputeEpilogue, + AlignmentComputeEpilogue, + NumEVTEpilogueStages>; + +using Bias = cutlass::epilogue::threadblock::VisitorAuxLoad< + BiasTileThreadMap, + ElementComputeEpilogue, + cute::Stride>; + +using ApplyBias = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::plus, ElementComputeEpilogue, ElementComputeEpilogue, + cutlass::FloatRoundStyle::round_to_nearest>; + +using EVTApplyBias 
= cutlass::epilogue::threadblock::Sm80EVT< + ApplyBias, + Accum, + Bias>; + +using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout< + ShapeMMAThreadBlock, + ShapeMMAWarp, + ElementOutput, + AlignmentOutput, + NumEVTEpilogueStages>; + +using Output = cutlass::epilogue::threadblock::VisitorAuxStore< + OutputTileThreadMap, ElementOutput, + cutlass::FloatRoundStyle::round_to_nearest, + cute::Stride>; + +using EVTOutput = cutlass::epilogue::threadblock::Sm80EVT< + Output, + EVTApplyBias>; + +// Use element type in EVT with the smallest bitwidth as ElementC. +using ElementC = ElementComputeEpilogue; +using LayoutC = LayoutOutput; + +using Gemm = + typename cutlass::gemm::device::SparseGemmWithVisitor< + ElementInputA, LayoutInputA, + ElementInputB, LayoutInputB, + ElementC, LayoutC, + ElementAccumulator, + MMAOp, + SmArch, + ShapeMMAThreadBlock, + ShapeMMAWarp, + ShapeMMAOp, + EVTOutput, + SwizzleThreadBlock, + NumStages, + AlignmentInputA, + AlignmentInputB, + Operator, + NumEVTEpilogueStages>; + +// Data type and layout of meta data matrix E can be inferred from template Gemm. +using ElementInputE = typename Gemm::GemmKernel::ElementE; +using LayoutInputE = cutlass::layout::RowMajor; +using ReorderedLayoutInputE = typename Gemm::GemmKernel::LayoutE; + +// Blow property is defined in include/cutlass/arch/sp_mma_sm80.h +// 50% Sparsity on Ampere +constexpr int kSparse = Gemm::kSparse; +// How many elements of A are covered per ElementE +constexpr int kElementsPerElementE = Gemm::kElementsPerElementE; +// The size of individual meta data +constexpr int kMetaSizeInBits = Gemm::kMetaSizeInBits; + +int run() { + + const int length_m = 512; + const int length_n = 512; + const int length_k = 1024; + + // Create a tuple of problem size for matrix multiplication + cutlass::gemm::GemmCoord problem_size(length_m, length_n, length_k); + + // Initialize tensors using CUTLASS helper functions + cutlass::HostTensor tensor_a( + cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse)); // <- Create matrix A with dimensions M x (K / 2) + cutlass::HostTensor tensor_a_uncompressed( + problem_size.mk()); // <- Create uncompressed matrix A with dimensions M x K for reference computing + + cutlass::HostTensor tensor_b( + problem_size.kn()); // <- Create matrix B with dimensions K x N + cutlass::HostTensor tensor_c( + problem_size.mn()); // <- Create matrix C with dimensions M x N + cutlass::HostTensor tensor_d( + problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from + // CUTLASS kernel + cutlass::HostTensor tensor_ref_d( + problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from + // reference kernel + + // Create matrix E with dimensions M x (K / 2 / kElementsPerElementE). This one is used by reference computing. + cutlass::HostTensor tensor_e( + cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); + // Same size as the above. The above one needs to be reordered and stored in this one. 
+ cutlass::HostTensor tensor_e_reordered( + cutlass::make_Coord(problem_size.m(), problem_size.k() / kSparse / kElementsPerElementE)); + + // Fill input and output matrices on host using CUTLASS helper functions + cutlass::reference::host::TensorFillRandomUniform( + tensor_a.host_view(), + 1, + ElementInputA(8), + ElementInputA(-8), + 0); // <- Fill matrix A on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_b.host_view(), + 1, + ElementInputB(8), + ElementInputB(-8), + 0); // <- Fill matrix B on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_c.host_view(), + 1, + ElementOutput(8), + ElementOutput(-8), + 0); // <- Fill matrix C on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomSparseMeta( + tensor_e.host_view(), + 1, + kMetaSizeInBits); // <- Fill matrix E on host with uniform-distribution random meta data + cutlass::reference::host::TensorFill( + tensor_d.host_view()); // <- fill matrix D on host with zeros + cutlass::reference::host::TensorFill( + tensor_ref_d.host_view()); // <- fill matrix D for reference on host with zeros + + // Reorder the meta data matrix so that we can use ldmatrix to load them to tensor core + // instructions. + cutlass::reorder_meta(tensor_e_reordered.host_ref(), tensor_e.host_ref(), + {problem_size.m(), problem_size.n(), + problem_size.k() / kSparse / kElementsPerElementE}); + + // Copy data from host to GPU + tensor_a.sync_device(); + tensor_b.sync_device(); + tensor_c.sync_device(); + tensor_d.sync_device(); + tensor_e_reordered.sync_device(); + tensor_ref_d.sync_device(); + + // Initialize alpha and beta for dot product computation + ElementComputeEpilogue alpha = ElementComputeEpilogue(1); + ElementComputeEpilogue beta = ElementComputeEpilogue(1); + + typename Bias::Arguments bias_arguments{ + tensor_c.device_data(), + ElementComputeEpilogue(0), + {problem_size.n(), cute::_1{}, problem_size.mn().product()} + }; + typename Output::Arguments output_arguments{ + tensor_d.device_data(), + {problem_size.n(), cute::_1{}, problem_size.mn().product()} + }; + typename EVTOutput::Arguments callback_arguments{ + { + {}, // Accum + bias_arguments, // Bias + {} // ApplyBias + }, // EVTApplyBias + output_arguments // Output + }; // EVTOutput + + // Create a tuple of gemm kernel arguments. 
This is later passed as arguments to launch + // instantiated CUTLASS kernel + typename Gemm::Arguments arguments{problem_size, // <- problem size of matrix multiplication + tensor_a.device_ref(), // <- reference to matrix A on device + tensor_b.device_ref(), // <- reference to matrix B on device + tensor_e_reordered.device_ref(), // <- reference to matrix E on device + callback_arguments}; // <- epilogue arguments + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm_op; + + // Check the problem size is supported or not + cutlass::Status status = gemm_op.can_implement(arguments); + CUTLASS_CHECK(status); + + // Initialize CUTLASS kernel with arguments and workspace pointer + status = gemm_op.initialize(arguments, workspace.get()); + CUTLASS_CHECK(status); + + // Launch initialized CUTLASS kernel + status = gemm_op(); + CUTLASS_CHECK(status); + + // uncompress tensor_a based on meta data tensor_e. We need it for reference computing. + cutlass::uncompress(tensor_a_uncompressed.host_ref(), tensor_a.host_ref(), + tensor_e.host_ref(), problem_size.m(), problem_size.k()); + + // Create instantiation for host reference gemm kernel + cutlass::reference::host::Gemm + gemm_host; + + // Launch host reference gemm kernel + gemm_host(problem_size, + alpha, + tensor_a_uncompressed.host_ref(), + tensor_b.host_ref(), + beta, + tensor_c.host_ref(), + tensor_ref_d.host_ref()); + + // Copy output data from CUTLASS host for comparison + tensor_d.sync_host(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::host::TensorEquals( + tensor_d.host_view(), + tensor_ref_d.host_view()); + + std::cout << (passed ? "Passed" : "Failed") << std::endl; + + return (passed ? 0 : -1); +} + +int main() { + + bool notSupported = false; + + // Ampere Sparse Tensor Core operations exposed with mma.sync and ldmatrix are first available + // in CUDA 11.1. + // + // CUTLASS must be compiled with CUDA 11.1 Toolkit to run these examples. + + if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 1))) { + std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.1 Toolkit or later." << std::endl; + notSupported = true; + } + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (props.major * 10 + props.minor < 80) { + std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." + << std::endl; + notSupported = true; + } + + if (notSupported) { + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + return run(); +} diff --git a/examples/16_ampere_tensorop_conv2dfprop/CMakeLists.txt b/examples/16_ampere_tensorop_conv2dfprop/CMakeLists.txt index e3afbb9650..cdc3f11ba3 100644 --- a/examples/16_ampere_tensorop_conv2dfprop/CMakeLists.txt +++ b/examples/16_ampere_tensorop_conv2dfprop/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without diff --git a/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu b/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu index 66b0dee500..c0395f5899 100644 --- a/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu +++ b/examples/16_ampere_tensorop_conv2dfprop/ampere_tensorop_conv2dfprop.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -31,83 +31,181 @@ /** -This example shows how to run convolution kernels using functions and data structures -provided by CUTLASS using tensor cores; which we run on a NVIDIA Ampere GPU. - -Writing a single high performance convolution kernel is hard but do-able. Whereas writing -high performance kernels at scale which works for multiple problem sizes with good abstractions is -really hard. CUTLASS solves this problem by providing simplified abstractions to compose -multiple sections of implicit gemm kernel. When used properly, the kernels can hit peak performance -of GPU easily. - -CUTLASS divides a kernel into hierarchical composable sections. Which means, at each thread, warp -and thread-block level, they compute on their own tile-size with higher level of tile sizes being -composed from lower level ones. Multiple thread-tiles (tile size each thread computes) can be used -to form warp-tiles (tile size each warp computes) and multiple warp tiles can be used to compute -threadblock-tile (tile size computed by a threadblock). - -In thie example, we split variable initialization into -1. Setting up data properties : describes how tensors are laid out in the memory and how the kernel -can view them (logical to physical mapping) -2. Setting up computation properties : describes how the above set tensors will be used to compute -output of convolution. - -First, we setup the data types of the input tensor A, weights' tensor B and output tensor C along -with alpha, beta as the equation for convolution is C = alpha * Conv2dFprop(A, B) + beta * C. In CUTLASS, -the kernels first compute Conv2dFprop(A, B) and leave the rest of the computation to end of the kernel as -alpha * X + beta * C is a simple element-wise operation on X (Conv2dFprop(A, B)) and C. We call this as -epilogue of kernel. Hence, we setup data types for alpha and beta to be equal to -ElementComputeEpilogue = float. We use the data type for elements in input tensor A and B as -cutlass::half_t. We convey this to CUTLASS kernel by initializing template variables ElementAccumulator (float), -ElementComputeEpilogue (float), ElementInputA (cutlass::half_t), ElementInputB (cutlass::half_t), -ElementOutput (float). Communicating just the data type is not enough. As the data is laid out -linearly in memory, we have to convey the layout of tensors. We do that by initializing template -variables LayoutInputA, LayoutInputB and LayoutOutput to TensorNHWC cutlass variable. Next, we setup -rules to comptue alpha * X + beta * C which is called epilogue of the kernel. 
We initialize template -variable EpilogueOp, which takes the data type of output ElementOutput (float), the number of -elements per vector memory access (8), data type of accumulator (float) and data type of -computation of linear combination (alpha * X + beta * C). - -Now that we setup the properties of data, we have to setup properties of computation. - -Second, we create template variables of tile sizes for thread-block, warp and mma-op to 128x128x64, -64x64x64, 16x8x16 (MxNxK) respectively. When passed to instantiate CUTLASS Implicit GEMM kernel, it -internally deduces the amount of threads needed per thread-block, amount of shared memory, storing -data in bank-conflict free manner, and ton of other variables required to compose, intialize and -launch a high performance Implicit GEMM kernel. This is the beauty of CUTLASS, it relieves developer -from understanding and coding complicated hardware optimizations which can easily go wrong. - -CUTLASS also supports multiple MMA pipelines in a threadblock. What are MMA pipelines? MMA pipelines -constitute the whole process of loading input data from global memory to shared memory, loading data -from shared memory to registers, doing matrix multiplication, store to global memory. The below flow -sequence shows a typical mma multistage pipeline. -(see include/cutlass/conv/threadblock/implicit_gemm_multistage.h) - -tensor in global memory --cp_async--> tile in shared memory --smem loads--> registers ---mma--> registers --global stores--> output to global memory - -NVIDIA Ampere uses `cp_async` to build multistage software pipeline to better hide latencies. - - -There are few more template variables initialized such as, which threadblock tile of output matrix -is done which threadblock launched on an SM, CUDA SM architecture of GPU you want to run on. - -These are all put together to create a template variable which describes CUTLASS Implicit GEMM -kernel using cutlass::conv::device::ImplicitGemm template. - -The next step is to intialize physical data, instantiate and initialize CUTLASS kernel and run it. -We use CUTLASS utilities to initialize, fill, compare tensors as they are simple and doesn't come -in the way of learning CUTLASS. - -Once all the tensors are initialized and filled with data, create arguments tuple to launch CUTLASS -kernel which takes problem size (N = 1, H = 64, W = 64, C = 128), filter size (K = 64, -R = 3, S = 3, C = 128 ), padding, strides, dilation, tensors, alpha, beta and the -important one, split k-dimension factor. Along with that, we query CUTLASS if any scratch-space -memory required by the kernel we instantiated. If yes, we create it and pass it along with other -arguments created to intialize CUTLASS kernel then, the kernel is launched. - -In this example, we later on launch a reference convolution kernel (from CUTLASS utilities) to -compare if the output from CUTLASS kernel is same as the reference implicit GEMM kernel. +This example shows how to run CUTLASS's convolution kernels +based on the Implicit GEMM algorithm, that use the Tensor Cores +on an NVIDIA Ampere GPU. + +Writing a single high-performance convolution kernel is hard enough, +let alone writing kernels that perform well for multiple problem sizes +and use good software abstractions. +CUTLASS provides simplified abstractions +to compose multiple sections of a convolution kernel. +When used properly, the kernels can reach peak GPU performance. 
+ +CUTLASS divides a kernel into hierarchical composable sections +for each level of the GPU hardware hierarchy: +thread, warp, and threadblock. +Each section computes on its own tile shape, +with each higher level's tile shape +being composed from lower-level tile shapes. +Multiple thread tiles (the tile shape each thread computes) +can be used to form warp tiles (the tile shape each warp computes), +and multiple warp tiles can be used to compute threadblock tiles +(the tile shape computed by a threadblock). + +In this example, we split variable initialization into two parts. + +1. Setting up data properties: describes how tensors are laid out in the memory + and how the kernel can view them (logical to physical mapping) + +2. Setting up computation properties: describes how the above tensors + will be used to compute the output of convolution + +We begin by setting up the data types +of all the input and output elements of a convolution. +A convolution computes +C = alpha * Conv2dFprop(A, B) + beta * C, +so we set up data types for the input tensor A, +weights tensor B, output tensor C, +and the scaling factors alpha and beta. +CUTLASS divides the convolution into two parts: +the "mainloop" that computes X = Conv2dFprop(A, B), +and the "epilogue" that computes C = alpha * X + beta * C. +The epilogue is an element-wise operation on X and C. +In this case, it is a linear combination, +but other epilogues are possible. + +In this example, we want + +* the scaling factors alpha and beta to be float, + +* the elements of A and B to be cutlass::half_t + (a 16-bit floating-point type), + +* the elements of C to be float, and + +* intermediate sums to be accumulated in float. + +We convey this to the CUTLASS kernel +by setting the following template parameters. + +* alpha and beta: ElementComputeEpilogue = float + +* Elements of input tensor A: ElementInputA = cutlass::half_t + +* Elements of input tensor B: ElementInputB = cutlass::half_t + +* Elements of output tensor C: ElementOutput = float + +* Accumulation type: ElementAccumulator = float + +Next, we describe the layout of the input and output tensors. +We convey this to the CUTLASS kernel +by setting the following template parameters. + +* Layout of input tensor A: LayoutInputA = TensorNHWC + +* Layout of input tensor B: LayoutInputB = TensorNHWC + +* Layout of output tensor C: LayoutOutput = TensorNHWC + +After that, we set up rules to compute the epilogue. +The epilogue in this case is a simple linear combination +C = alpha * X + beta * C. +Thus, we set the kernel's template parameter EpilogueOp +to LinearCombination. LinearCombination itself +has template parameters: + +* the element type of the output tensor (ElementOutput), + +* the number of elements per vector memory access (8), + +* the data type of the accumulator (ElementAccumulator), + +* and the data type used to compute the linear combination + (ElementComputeEpilogue). + +We then define the tile shapes +that each level of the computation uses. +We define these as types that encode the tile shapes +as compile-time integer values. +Each shape expresses the dimensions M x N x K. +Here, the letters refer to the dimensions +of a matrix-matrix multiply. + +* ThreadblockShape defines the threadblock tile shape + as 128 x 128 x 64. + +* WarpShape defines the warp tile shape as 64 x 64 x 64. + +* InstructionShape defines the MMA + (matrix multiply-accumulate) operation shape + as 16 x 8 x 16. 
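(Aside for the reader: the tile shapes listed above are expressed as compile-time types. The following is a minimal sketch, assuming the cutlass::gemm::GemmShape template and the alias names this example defines later in the file; the static_asserts are only illustrative checks that each tile level evenly divides the level above it, not part of the example.)

    #include "cutlass/gemm/gemm.h"   // cutlass::gemm::GemmShape

    // Tile shapes for each level of the hierarchy, expressed as M x N x K.
    using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>;
    using WarpShape        = cutlass::gemm::GemmShape<64, 64, 64>;
    using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>;

    // A threadblock tile is composed of warp tiles, and a warp tile is composed
    // of MMA instruction tiles, so each level must evenly divide the one above it.
    static_assert(ThreadblockShape::kM % WarpShape::kM == 0 &&
                  ThreadblockShape::kN % WarpShape::kN == 0 &&
                  ThreadblockShape::kK % WarpShape::kK == 0,
                  "Warp tile must evenly divide the threadblock tile");
    static_assert(WarpShape::kM % InstructionShape::kM == 0 &&
                  WarpShape::kN % InstructionShape::kN == 0 &&
                  WarpShape::kK % InstructionShape::kK == 0,
                  "MMA instruction tile must evenly divide the warp tile");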
+ +These types become template arguments +of the kernel properties type +cutlass::conv::kernel::DefaultConv2dFprop. +The kernel uses these shapes to deduce +the number of threads needed per threadblock, +the required amount of shared memory, +the internal layouts needed to access +shared memory without bank conflicts, +and many other properties that the kernel needs +for good performance. +CUTLASS deduces all these properties automatically, +so that users don't have to. +DefaultConv2dFprop accepts other template parameters +that describe things like the target CUDA SM architecture. + +CUTLASS also supports multiple MMA pipelines in a threadblock. +An MMA pipeline constitutes the whole process +of loading input data from global memory to shared memory, +loading data from shared memory to registers, +doing matrix multiplication, +and storing the result to global memory. +The below flow sequence shows a typical MMA multistage pipeline +(see include/cutlass/conv/threadblock/implicit_gemm_multistage.h). + +tensor in global memory +--cp_async--> +tile in shared memory +--smem loads--> +registers +--mma--> +registers +--global stores--> +output to global memory + +On NVIDIA Ampere, the kernel uses `cp_async` +to build a multistage software pipeline. +This helps it better hide latency. + +At this point, we can define the actual CUTLASS kernel type +as the alias ImplicitGemm, a specialization of +cutlass::conv::device::ImplicitGemmConvolution. +The latter accepts the kernel properties type alias +Conv2dFpropKernel as its one template argument. + +This example then sets up a test problem +and arguments to the kernel. +We use CUTLASS utilities to allocate +the input and output tensors +and fill them with sample input data. +We then create the kernel arguments +as an instance of ImplicitGemm::Arguments. +The arguments include +the problem size (N = 1, H = 64, W = 64, C = 128), +filter size (K = 64, R = 3, S = 3, C = 128), +padding, strides, dilation, tensors, alpha, beta, +and the split k-dimension factor. +We also query CUTLASS if the kernel we instantiated +requires any memory for scratch space. +If yes, we reserve scratch space and pass it along +with other arguments to initialize the CUTLASS kernel. + +After launching the CUTLASS kernel, this example runs +a reference convolution kernel (from CUTLASS utilities) +to check correctness.
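(Aside for the reader: a minimal sketch of the host-side flow just described, assuming the ImplicitGemm alias, the tensor_a/b/c/d host tensors, problem_size, options, and the CUTLASS_CHECK helper that appear later in this file. It illustrates the sequence of calls; it is not the example's exact code.)

    // Build the argument structure described above.
    typename ImplicitGemm::Arguments arguments{
      problem_size,
      tensor_a.device_ref(),
      tensor_b.device_ref(),
      tensor_c.device_ref(),
      tensor_d.device_ref(),
      {options.alpha, options.beta}
    };

    ImplicitGemm implicit_gemm_op;

    // Query how much scratch space the kernel needs and allocate it.
    size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
    cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);

    // Verify the problem is supported, initialize the kernel, then launch it.
    CUTLASS_CHECK(implicit_gemm_op.can_implement(arguments));
    CUTLASS_CHECK(implicit_gemm_op.initialize(arguments, workspace.get()));
    CUTLASS_CHECK(implicit_gemm_op());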
*/ #include @@ -131,8 +229,8 @@ compare if the output from CUTLASS kernel is same as the reference implicit GEMM #include "helper.h" -// The code section below describes datatype for input, output tensors and computation between -// elements +// Data types for input and output tensors +// and computation between elements using ElementAccumulator = float; // Data type of accumulator using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta) using ElementInputA = cutlass::half_t; // Data type of elements in input tensor @@ -143,39 +241,44 @@ using LayoutInputA = cutlass::layout::TensorNHWC; using LayoutInputB = cutlass::layout::TensorNHWC; using LayoutOutput = cutlass::layout::TensorNHWC; -// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM +// Whether to use tensor cores or regular SIMT cores on GPU SM using MMAOp = cutlass::arch::OpClassTensorOp; -// This code section describes CUDA SM architecture number +// SM architecture number using SmArch = cutlass::arch::Sm80; -// This code section describes the tile size a thread block will compute -using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; // Threadblock tile shape +// Threadblock tile shape +using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 64>; -// This code section describes tile size a warp will compute -using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; // Warp tile shape +// Warp tile shape +using WarpShape = cutlass::gemm::GemmShape<64, 64, 64>; -// This code section describes the size of MMA op -using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape +// MMA (Tensor Core instruction, in this case) tile shape +using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; -// This code section describes how threadblocks are scheduled on GPU +// How the kernel schedules threadblocks using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; -// Number of pipelines you want to use +// Number of pipeline stages to use constexpr int NumStages = 3; -// This code section describe iterator algorithm selected is Analytic or Optimized +// Which iterator algorithm to use: Analytic or Optimized static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized; -// This code section describes the epilogue part of the kernel, we use default value +// Is the output packed or strided +// Use kStride if using strided output +static cutlass::conv::StrideSupport const OutputStride = cutlass::conv::StrideSupport::kUnity; + +// The epilogue part of the kernel using EpilogueOp = cutlass::epilogue::thread::LinearCombination< ElementOutput, // Data type of output matrix. - 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. + 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized // memory access. This becomes the vector width of // math instructions in the epilogue too. 
ElementAccumulator, // Data type of accumulator ElementComputeEpilogue>; // Data type for alpha/beta in linear combination +// Kernel properties type using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< ElementInputA, LayoutInputA, ElementInputB, LayoutInputB, @@ -190,9 +293,11 @@ using Conv2dFpropKernel = typename cutlass::conv::kernel::DefaultConv2dFprop< SwizzleThreadBlock, NumStages, cutlass::arch::OpMultiplyAdd, - IteratorAlgorithm + IteratorAlgorithm, + OutputStride >::Kernel; +// Type of the actual kernel using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -230,7 +335,7 @@ struct Options { beta(0), benchmark(false) { } - // Verify the problem size is compatible with the CUTLASS Convolution implementation. + // Verify that the problem size is compatible with CUTLASS's convolution implementation bool valid() { // @@ -256,7 +361,7 @@ struct Options { return true; } - /// Updates input and filter sizes + /// Update input and filter sizes void update( cutlass::Tensor4DCoord input_size, cutlass::Tensor4DCoord filter_size) { @@ -270,7 +375,7 @@ struct Options { padding.c() = filter_size.w() / 2; } - // Parses the command line + // Parse command-line arguments void parse(int argc, char const **args) { cutlass::CommandLine cmd(argc, args); @@ -302,11 +407,11 @@ struct Options { cmd.get_cmd_line_argument("k", filter_size.n()); cmd.get_cmd_line_argument("r", filter_size.h()); cmd.get_cmd_line_argument("s", filter_size.w()); - filter_size.c() = input_size.c(); + filter_size.c() = input_size.c(); cmd.get_cmd_line_argument("alpha", alpha); cmd.get_cmd_line_argument("beta", beta); - + cmd.get_cmd_line_argument("iterations", iterations); cmd.get_cmd_line_argument("tag", tag); @@ -320,12 +425,12 @@ struct Options { } } - /// Prints the usage statement. + /// Print an explanation of the command-line arguments std::ostream & print_usage(std::ostream &out) const { out << "16_ampere_tensorop_conv2dfprop example\n\n" - << " This example uses Ampere's Tensor Core operators on F16 data types to compute\n" - << " forward convolution on tensors of layout NHWC.\n\n" + << " This example uses Ampere's Tensor Core operators on F16 data types\n" + << " to compute forward convolution on tensors of layout NHWC.\n\n" << "Options:\n\n" << " --help If specified, displays this usage statement.\n\n" << " --n= Input tensor extent N\n" @@ -350,7 +455,7 @@ struct Options { return out; } - + /// Computes the output tensor size (NPQK) cutlass::Tensor4DCoord output_size() const { return cutlass::Tensor4DCoord( @@ -360,19 +465,20 @@ struct Options { filter_size.n()); } - /// Compute performance in GFLOP/s + /// Compute performance in Gflop/s + /// + /// Gflop/s stands for billions (10^9) of + /// floating-point operations per second (Gflop/s). 
double gflops(double runtime_s) const { // Number of multiply-adds = NPQK * CRS int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c()); - + // Two flops per multiply-add return 2.0 * double(fmas) / double(1.0e9) / runtime_s; } }; -///////////////////////////////////////////////////////////////////////////////////////////////// - struct Result { double runtime_ms; double gflops; @@ -380,14 +486,14 @@ struct Result { cutlass::Status reference_check; cudaError_t error; - Result(): - runtime_ms(0), + Result(): + runtime_ms(0), gflops(0), status(cutlass::Status::kSuccess), reference_check(cutlass::Status::kInvalid), error(cudaSuccess) { } - static std::ostream & print_header(std::ostream &out, Options const &options) { + static std::ostream& print_header(std::ostream &out, Options const &options) { if (!options.tag.empty()) { out << "Name,"; @@ -404,7 +510,7 @@ struct Result { out << options.tag << ","; } - out + out << "conv_" << idx << "," << options.input_size.n() << "," << options.input_size.h() << "," @@ -420,8 +526,6 @@ struct Result { } }; -///////////////////////////////////////////////////////////////////////////////////////////////// - /// Runs one benchmark Result profile_convolution(Options const &options) { @@ -441,7 +545,7 @@ Result profile_convolution(Options const &options) { // Initialize tensors // - // Fill tensor A on host with uniform-distribution random data + // Fill tensor A on host with uniformly distributed random data cutlass::reference::host::TensorFillRandomUniform( tensor_a.host_view(), 1, @@ -449,7 +553,7 @@ Result profile_convolution(Options const &options) { ElementInputA(-8), 0); - // Fill tensor B on host with uniform-distribution random data + // Fill tensor B on host with uniformly distributed random data cutlass::reference::host::TensorFillRandomUniform( tensor_b.host_view(), 1, @@ -457,9 +561,13 @@ Result profile_convolution(Options const &options) { ElementInputB(-8), 0); - // Fill tensor C on host with zeros - cutlass::reference::host::TensorFill( - tensor_c.host_view()); + // Fill tensor C on host with uniformly distributed random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_c.host_view(), + 1, + ElementOutput(7), + ElementOutput(-8), + 0); // Fill tensor D on host with zeros cutlass::reference::host::TensorFill( @@ -486,7 +594,7 @@ Result profile_convolution(Options const &options) { int split_k_slices = 1; // Construct Conv2dProblemSize with user defined output size - cutlass::conv::Conv2dProblemSize problem_size( + cutlass::conv::Conv2dProblemSize problem_size( options.input_size, options.filter_size, options.padding, @@ -497,7 +605,7 @@ Result profile_convolution(Options const &options) { split_k_slices ); - // Construct ImplicitGemm::Argument structure with conv2d + // Construct ImplicitGemm::Argument structure with conv2d // problem size, data pointers, and epilogue values typename ImplicitGemm::Arguments arguments{ problem_size, @@ -535,7 +643,7 @@ Result profile_convolution(Options const &options) { // // Optional reference check // - + if (options.reference_check) { std::cout << "Verification on host...\n"; @@ -548,8 +656,7 @@ Result profile_convolution(Options const &options) { ElementOutput, LayoutOutput, ElementComputeEpilogue, - ElementAccumulator, - cutlass::NumericConverter + ElementAccumulator >( problem_size, tensor_a.host_ref(), @@ -560,7 +667,7 @@ Result profile_convolution(Options const &options) { options.beta ); - // Check if output from CUTLASS kernel and 
reference kernel are equal or not + // Check if CUTLASS kernel and reference kernel produced the same output tensor_d.sync_host(); bool passed = cutlass::reference::host::TensorEquals( @@ -585,14 +692,14 @@ Result profile_convolution(Options const &options) { std::stringstream ss; ss << "16_ampere_workspace_conv2dfprop_" - << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c() + << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c() << "_" - << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c() + << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c() << ".dat"; std::ofstream output_workspace(ss.str()); - output_workspace + output_workspace << "Input = \n" << tensor_a.host_view() << "\n\n" << "Filters = \n" << tensor_b.host_view() << "\n\n"; @@ -612,7 +719,7 @@ Result profile_convolution(Options const &options) { if (options.measure_performance) { cudaEvent_t events[2]; - + for (auto & event : events) { result.error = cudaEventCreate(&event); if (result.error != cudaSuccess) { @@ -628,7 +735,7 @@ Result profile_convolution(Options const &options) { return result; } - // Launch a sequence of implicit GEMM operations on the device + // Launch a sequence of implicit GEMM operations on the device. for (int iteration = 0; iteration < options.iterations; ++iteration) { result.status = implicit_gemm_op(); CUTLASS_CHECK(result.status); @@ -648,7 +755,7 @@ Result profile_convolution(Options const &options) { return result; } - // Measure elapsed runtime + // Measure elapsed runtime. float runtime_ms = 0; result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); if (result.error != cudaSuccess) { @@ -656,7 +763,7 @@ Result profile_convolution(Options const &options) { return result; } - // Print average runtime and GFLOPs. + // Print average run time and floating-point throughput (Gflop/s). result.runtime_ms = double(runtime_ms) / double(options.iterations); result.gflops = options.gflops(result.runtime_ms / 1000.0); @@ -669,8 +776,6 @@ Result profile_convolution(Options const &options) { return result; } -///////////////////////////////////////////////////////////////////////////////////////////////// - int main(int argc, char const **args) { bool notSupported = false; @@ -686,7 +791,7 @@ int main(int argc, char const **args) { cudaDeviceProp props; CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); - if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) { + if (!(props.major >= 8)) { std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." << std::endl; notSupported = true; @@ -697,7 +802,7 @@ int main(int argc, char const **args) { } Options options; - + options.parse(argc, args); if (options.help) { @@ -764,5 +869,3 @@ int main(int argc, char const **args) { return 0; } - -///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/17_fprop_per_channel_bias/CMakeLists.txt b/examples/17_fprop_per_channel_bias/CMakeLists.txt index 5ca41a419c..350a27998a 100644 --- a/examples/17_fprop_per_channel_bias/CMakeLists.txt +++ b/examples/17_fprop_per_channel_bias/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without diff --git a/examples/17_fprop_per_channel_bias/fprop_per_channel_bias.cu b/examples/17_fprop_per_channel_bias/fprop_per_channel_bias.cu index 2b6b25c73c..f1658c0fe6 100644 --- a/examples/17_fprop_per_channel_bias/fprop_per_channel_bias.cu +++ b/examples/17_fprop_per_channel_bias/fprop_per_channel_bias.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -290,7 +290,7 @@ int main(int argc, char const **args) { cudaDeviceProp props; CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); - if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) { + if (!(props.major >= 8)) { std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." << std::endl; notSupported = true; diff --git a/examples/18_ampere_fp64_tensorop_affine2_gemm/CMakeLists.txt b/examples/18_ampere_fp64_tensorop_affine2_gemm/CMakeLists.txt index 21a21e90de..5f4541c3fb 100644 --- a/examples/18_ampere_fp64_tensorop_affine2_gemm/CMakeLists.txt +++ b/examples/18_ampere_fp64_tensorop_affine2_gemm/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without diff --git a/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu b/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu index 62450e2138..1595dd6088 100644 --- a/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu +++ b/examples/18_ampere_fp64_tensorop_affine2_gemm/ampere_fp64_tensorop_affine2_gemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -326,7 +326,7 @@ int main(int argc, char const **args) { cudaDeviceProp props; CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); - if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) { + if (!(props.major >= 8)) { std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." << std::endl; notSupported = true; diff --git a/examples/19_tensorop_canonical/CMakeLists.txt b/examples/19_tensorop_canonical/CMakeLists.txt index c48ff26008..140f51bf92 100644 --- a/examples/19_tensorop_canonical/CMakeLists.txt +++ b/examples/19_tensorop_canonical/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without diff --git a/examples/19_tensorop_canonical/tensorop_canonical.cu b/examples/19_tensorop_canonical/tensorop_canonical.cu index 15ad17f003..1f0aa93282 100644 --- a/examples/19_tensorop_canonical/tensorop_canonical.cu +++ b/examples/19_tensorop_canonical/tensorop_canonical.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/20_simt_canonical/CMakeLists.txt b/examples/20_simt_canonical/CMakeLists.txt index 25f2969f3c..36dcda7af6 100644 --- a/examples/20_simt_canonical/CMakeLists.txt +++ b/examples/20_simt_canonical/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without diff --git a/examples/20_simt_canonical/simt_canonical.cu b/examples/20_simt_canonical/simt_canonical.cu index d905d4da60..8f2fbc4de0 100644 --- a/examples/20_simt_canonical/simt_canonical.cu +++ b/examples/20_simt_canonical/simt_canonical.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/21_quaternion_gemm/CMakeLists.txt b/examples/21_quaternion_gemm/CMakeLists.txt index 14d71c0796..742a9e7868 100644 --- a/examples/21_quaternion_gemm/CMakeLists.txt +++ b/examples/21_quaternion_gemm/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without diff --git a/examples/21_quaternion_gemm/quaternion_gemm.cu b/examples/21_quaternion_gemm/quaternion_gemm.cu index 95c7b4d7a2..025a9c974e 100644 --- a/examples/21_quaternion_gemm/quaternion_gemm.cu +++ b/examples/21_quaternion_gemm/quaternion_gemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/examples/22_quaternion_conv/CMakeLists.txt b/examples/22_quaternion_conv/CMakeLists.txt index 9bfad4ff98..52e1727977 100644 --- a/examples/22_quaternion_conv/CMakeLists.txt +++ b/examples/22_quaternion_conv/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without diff --git a/examples/22_quaternion_conv/quaternion_conv.cu b/examples/22_quaternion_conv/quaternion_conv.cu index 756d465124..bc7173d1ac 100644 --- a/examples/22_quaternion_conv/quaternion_conv.cu +++ b/examples/22_quaternion_conv/quaternion_conv.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -470,8 +470,7 @@ Result profile_convolution(Options const &options) { ElementOutput, LayoutOutput, ElementComputeEpilogue, - ElementAccumulator, - cutlass::NumericConverter + ElementAccumulator >( problem_size, tensor_a.host_ref(), diff --git a/examples/23_ampere_gemm_operand_reduction_fusion/CMakeLists.txt b/examples/23_ampere_gemm_operand_reduction_fusion/CMakeLists.txt index 49d313f4dc..e5b4ec0351 100644 --- a/examples/23_ampere_gemm_operand_reduction_fusion/CMakeLists.txt +++ b/examples/23_ampere_gemm_operand_reduction_fusion/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -27,10 +27,14 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - +set(TEST_STANDARD --m=1024 --n=1024 --k=1024) +set(TEST_LARGE_PERFCHECK --m=4096 --n=3456 --k=4096 --perf-check) cutlass_example_add_executable( 23_ampere_gemm_operand_reduction_fusion ampere_gemm_operand_reduction_fusion.cu + TEST_COMMAND_OPTIONS + TEST_STANDARD + TEST_LARGE_PERFCHECK ) diff --git a/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu b/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu index 41ea3200a1..4e5fca1a03 100644 --- a/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu +++ b/examples/23_ampere_gemm_operand_reduction_fusion/ampere_gemm_operand_reduction_fusion.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -30,13 +30,13 @@ **************************************************************************************************/ /** -The example demenstrates how to reduce one of the operands of the GEMM along the k-dimension when +The example demonstrates how to reduce one of the operands of the GEMM along the k-dimension when computing GEMM. So the output also contains either a Mx1 or 1XN vector. It only works with Ampere -HMMA 16x8x16 FP16 tensor cores, though it is not difficult to apply to other Turing/Ampere tensor +16x8x16 FP16/BF16 tensor cores, though it is not difficult to apply to other Turing/Ampere tensor core instructions. 
Most of the reduction is done in gemm/warp level, see gemm/warp/mma_with_reduction_tensor_op.h -A few bit of reduction is done in the epilouge before storing the vector, see +A small amount of the reduction is done in the epilogue before storing the vector, see epilogue/threadblock/epilogue_gemm_k_reduction.h */ @@ -45,7 +45,7 @@ epilogue/threadblock/epilogue_gemm_k_reduction.h #include #include "cutlass/cutlass.h" -#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/device/gemm_with_k_reduction.h" #include "cutlass/gemm/kernel/default_gemm_with_k_reduction.h" #include "cutlass/reduction/device/reduce_split_k.h" #include "cutlass/reduction/kernel/reduce_split_k.h" @@ -67,9 +67,9 @@ epilogue/threadblock/epilogue_gemm_k_reduction.h // elements using ElementAccumulator = float; // Data type of accumulator using ElementComputeEpilogue = ElementAccumulator; // Data type of epilogue computation -using ElementInputA = cutlass::half_t; // Data type of elements in input tensor -using ElementInputB = cutlass::half_t; // Data type of elements in input tensor -using ElementOutput = cutlass::half_t; // Data type of elements in output tensor +using ElementInputA = cutlass::bfloat16_t; // Data type of elements in input tensor +using ElementInputB = cutlass::bfloat16_t; // Data type of elements in input tensor +using ElementOutput = cutlass::bfloat16_t; // Data type of elements in output tensor using LayoutInputA = cutlass::layout::ColumnMajor; using LayoutInputB = cutlass::layout::RowMajor; @@ -101,6 +101,12 @@ constexpr int NumStages = 4; // Reduce A or B operand along the K dimension constexpr bool ReduceKForA = true; +// Alignment of A operand +constexpr int AlignmentA = 8; + +// Alignment of B operand +constexpr int AlignmentB = 8; + // This code section describes the epilogue part of the kernel, we use default value using EpilogueOp = cutlass::epilogue::thread::LinearCombination< ElementOutput, // Data type of output matrix.
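(Aside for the reader, tied to the fusion described at the top of this file's changes: with ReduceKForA == true the kernel produces, in addition to D = alpha * A * B + beta * C, an M x 1 vector holding the sum of operand A over the K dimension. A minimal host-side sketch follows, with placeholder names and a row-major A assumed for simplicity rather than this example's ColumnMajor layout.)

    // reduction[m] = sum over k of A(m, k): an M x 1 vector produced alongside D.
    void reduce_operand_a_along_k(float const *A, float *reduction, int M, int K) {
      for (int m = 0; m < M; ++m) {
        float sum = 0.0f;
        for (int k = 0; k < K; ++k) {
          sum += A[m * K + k];   // accumulate operand A along the K dimension
        }
        reduction[m] = sum;
      }
    }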
@@ -110,9 +116,9 @@ using EpilogueOp = cutlass::epilogue::thread::LinearCombination< ElementAccumulator, // Data type of accumulator ElementComputeEpilogue>; -using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithKReduction< - ElementInputA, LayoutInputA, cutlass::ComplexTransform::kNone, 8, - ElementInputB, LayoutInputB, cutlass::ComplexTransform::kNone, 8, +using Gemm = typename cutlass::gemm::device::GemmWithKReduction< + ElementInputA, LayoutInputA, + ElementInputB, LayoutInputB, ElementOutput, LayoutOutput, ElementAccumulator, MMAOp, @@ -124,13 +130,15 @@ using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmWithKReduction< EpilogueOp, SwizzleThreadBlock, NumStages, - cutlass::arch::OpMultiplyAdd ->::GemmKernel; - -using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + AlignmentA, + AlignmentB, + cutlass::arch::OpMultiplyAdd, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone +>; // Below is the reduction kernel used in the case of parallel split-k -using ReduceGemmSplitKShape = cutlass::MatrixShape<4, 64>;; +using ReduceGemmSplitKShape = cutlass::MatrixShape<4, 64>; using ReduceOp = cutlass::reduction::thread::ReduceAdd< ElementAccumulator, @@ -146,7 +154,7 @@ using ReduceGemmSplitKKernel = cutlass::reduction::kernel::ReduceSplitK< using ReduceGemmSplitK = cutlass::reduction::device::ReduceSplitK; -using ReduceVectorSplitKShape = cutlass::MatrixShape<1, 256>;; +using ReduceVectorSplitKShape = cutlass::MatrixShape<1, 256>; // This code section describes the epilogue part of the kernel, we use default value using DummyEpilogueOp = cutlass::epilogue::thread::LinearCombination< @@ -271,7 +279,7 @@ struct Options { /// Prints the usage statement. std::ostream & print_usage(std::ostream &out) const { - out << "28_ampere_gemm_bias_fusion example\n\n" + out << "23_ampere_operand_gemm_reduction_fusion\n\n" << "Options:\n\n" << " --help If specified, displays this usage statement.\n\n" << " --m= GEMM M\n" @@ -289,7 +297,7 @@ struct Options { << " --tag= String to replicate across the first column in the results table\n"; out << "\n\nExamples:\n\n" - << "$ ./examples/23_ampere_gemm_bias_fusion_example/ampere_gemm_bias_fusion --m=1024 --n=1024 --k=1024 \n\n"; + << "$ ./examples/23_ampere_gemm_operand_reduction_fusion/23_ampere_gemm_operand_reduction_fusion --m=1024 --n=1024 --k=1024 \n\n"; return out; } @@ -368,23 +376,23 @@ Result profile(Options const &options) { // Fill input and output matrices on host using CUTLASS helper functions cutlass::reference::host::TensorFillRandomUniform( tensor_a.host_view(), - 1, - ElementInputA(4), - ElementInputA(-4), + 1997, + ElementInputA(1), + ElementInputA(-1), 0); // <- Fill tensor A on host with uniform-distribution random data cutlass::reference::host::TensorFillRandomUniform( tensor_b.host_view(), - 1, - ElementInputB(4), - ElementInputB(-4), + 2003, + ElementInputB(1), + ElementInputB(-1), 0); // <- Fill tensor B on host with uniform-distribution random data cutlass::reference::host::TensorFillRandomUniform( tensor_c.host_view(), - 1, - ElementOutput(4), - ElementOutput(-4), + 2017, + ElementOutput(1), + ElementOutput(-1), 0); // <- Fill matrix C on host with uniform-distribution random data cutlass::reference::host::TensorFill( tensor_d.host_view()); // <- fill matrix D on host with zeros @@ -418,7 +426,7 @@ Result profile(Options const &options) { // Create a tuple of gemm kernel arguments. 
This is later passed as arguments to launch // instantiated CUTLASS kernel - typename Gemm::Arguments arguments{ + typename Gemm::Arguments arguments( mode, options.problem_size, batch_count, @@ -437,8 +445,7 @@ Result profile(Options const &options) { tensor_b.layout().stride(0), tensor_c.layout().stride(0), tensor_d.layout().stride(0), - tensor_reduction.layout().stride(0) - }; + tensor_reduction.layout().stride(0)); // Instantiate CUTLASS kernel depending on templates Gemm gemm_op; @@ -507,15 +514,14 @@ Result profile(Options const &options) { cutlass::TensorRef tensor_nullptr_tensorref(nullptr, splitk_vector_layout); - typename ReduceVectorSplitK::Arguments reduce_vector_splitk_arguments{ + typename ReduceVectorSplitK::Arguments reduce_vector_splitk_arguments( cutlass::MatrixCoord(1, reduce_vector_length), batch_count, size_t(reduce_vector_length), workspace_vector_tensorref, tensor_reduction_tensorref, tensor_nullptr_tensorref, - {1.0f, 0.0f} - }; + {1.0f, 0.0f}); ReduceVectorSplitK reduce_vector_splitk_op; @@ -561,7 +567,7 @@ Result profile(Options const &options) { tensor_reduction.sync_host(); - // Compute bias + relu in host code + // Reduce K in host code if (ReduceKForA) { for (int m = 0; m < options.problem_size.m(); ++m) { for (int k = 0; k < options.problem_size.k(); ++k) { @@ -581,7 +587,7 @@ Result profile(Options const &options) { // Check if output from CUTLASS kernel and reference kernel are equal or not bool pass = cutlass::reference::host::TensorEquals(tensor_d.host_view(), tensor_ref_d.host_view()); - + pass &= cutlass::reference::host::TensorEquals(tensor_ref_reduction.host_view(), tensor_reduction.host_view()); @@ -612,10 +618,10 @@ Result profile(Options const &options) { if (options.reference_check) { output_workspace << "Reference D = \n" << tensor_ref_d.host_view() << "\n\n"; - output_workspace << "Reference reduction vector= \n" << tensor_ref_reduction.host_view() << "\n\n"; + output_workspace << "Reference reduction vector = \n" << tensor_ref_reduction.host_view() << "\n\n"; } - output_workspace << "Computed = \n" << tensor_d.host_view() << std::endl; + output_workspace << "Computed D = \n" << tensor_d.host_view() << std::endl; output_workspace << "Computed reduction vector = \n" << tensor_reduction.host_view() << std::endl; std::cout << "Results written to '" << ss.str() << "'." << std::endl; @@ -699,7 +705,7 @@ int main(int argc, char const **args) { cudaDeviceProp props; CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); - if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) { + if (!(props.major >= 8)) { std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." << std::endl; notSupported = true; diff --git a/examples/24_gemm_grouped/CMakeLists.txt b/examples/24_gemm_grouped/CMakeLists.txt index c9f3558e4c..32614a075d 100644 --- a/examples/24_gemm_grouped/CMakeLists.txt +++ b/examples/24_gemm_grouped/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -31,6 +31,7 @@ cutlass_example_add_executable( 24_gemm_grouped - gemm_grouped.cu + gemm_grouped.cu ) + diff --git a/examples/24_gemm_grouped/gemm_grouped.cu b/examples/24_gemm_grouped/gemm_grouped.cu index a32c80d755..993d554f64 100644 --- a/examples/24_gemm_grouped/gemm_grouped.cu +++ b/examples/24_gemm_grouped/gemm_grouped.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -37,7 +37,7 @@ leading dimensions and problem sizes are stored in arrays in GMEM. This differs from "Batched Array" GEMM because the size of each GEMM problem in the Grouped GEMM - concept may be distinct. + concept may be distinct. This benchmark program initializes a workspace with random problem sizes for a given number of groups. Command line options enable overriding M, N, and/or K dimensions with uniform values to @@ -66,6 +66,7 @@ ///////////////////////////////////////////////////////////////////////////////////////////////// +#include #include #include #include @@ -98,6 +99,7 @@ struct Result { double runtime_ms; + double initialization_time_ms; double gflops; cutlass::Status status; cudaError_t error; @@ -109,11 +111,13 @@ struct Result { Result( double runtime_ms = 0, + double initialization_time_ms = 0, double gflops = 0, cutlass::Status status = cutlass::Status::kSuccess, cudaError_t error = cudaSuccess ): - runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } + runtime_ms(runtime_ms), initialization_time_ms(initialization_time_ms), gflops(gflops), + status(status), error(error), passed(true) { } }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -134,6 +138,8 @@ struct Options { bool help; bool error; bool reference_check; + bool profile_initialization; + bool sort_problems; std::vector problem_sizes; @@ -155,21 +161,47 @@ struct Options { std::string output_tag; std::ofstream output_file; + using GroupScheduleMode = cutlass::gemm::kernel::GroupScheduleMode; + std::vector scheduler_modes; + + std::unordered_map + str_to_scheduler_mode = { + {"kDeviceOnly", GroupScheduleMode::kDeviceOnly}, + {"kHostPrecompute", GroupScheduleMode::kHostPrecompute} + }; + + struct GroupScheduleModeHash { + size_t operator()(GroupScheduleMode m) const { + return static_cast(m); + } + }; + + std::unordered_map + scheduler_mode_to_str = { + {GroupScheduleMode::kDeviceOnly, "kDeviceOnly"}, + {GroupScheduleMode::kHostPrecompute, "kHostPrecompute"} + }; + + std::vector all_scheduler_modes = {GroupScheduleMode::kDeviceOnly, GroupScheduleMode::kHostPrecompute}; + // // Methods - // + // Options(): help(false), error(false), alignment(8), reference_check(true), + profile_initialization(false), + sort_problems(false), problem_count(15), iterations(20), cuda_streams(0), verbose(false), alpha(1), - beta() + beta(), + scheduler_modes({GroupScheduleMode::kDeviceOnly}) { } // Parses the command line @@ -184,13 +216,40 @@ struct Options { cmd.get_cmd_line_argument("alignment", alignment, 8); cmd.get_cmd_line_argument("groups", problem_count, 15); cmd.get_cmd_line_argument("alpha", alpha, 1.0f); - 
cmd.get_cmd_line_argument("beta", beta, 0.0f); + cmd.get_cmd_line_argument("beta", beta, 0.0f); cmd.get_cmd_line_argument("iterations", iterations, 20); cmd.get_cmd_line_argument("streams", cuda_streams, 0); cmd.get_cmd_line_argument("verbose", verbose, false); cmd.get_cmd_line_argument("reference-check", reference_check, true); + cmd.get_cmd_line_argument("profile-initialization", profile_initialization, false); + cmd.get_cmd_line_argument("sort-problems", sort_problems, false); cmd.get_cmd_line_argument("benchmark", benchmark_path); + std::vector scheduler_mode_strs; + cmd.get_cmd_line_arguments("scheduler-modes", scheduler_mode_strs); + + if (!scheduler_mode_strs.empty()) { + scheduler_modes.clear(); + if (scheduler_mode_strs.size() == 1 && scheduler_mode_strs[0] == "all") { + scheduler_modes = all_scheduler_modes; + } else { + for (std::string precomp_str : scheduler_mode_strs) { + auto it = str_to_scheduler_mode.find(precomp_str); + if (it != str_to_scheduler_mode.end()) { + scheduler_modes.push_back(it->second); + } else if (precomp_str == "all") { + std::cerr << "Flag --scheduler-modes=all must not contain other scheduler modes in list." << std::endl; + error = true; + return; + } else { + std::cerr << "Unrecognized scheduler mode '" << precomp_str << "'" << std::endl; + error = true; + return; + } + } + } + } + std::string output_path; cmd.get_cmd_line_argument("tag", output_tag); cmd.get_cmd_line_argument("output_file", output_path); @@ -314,6 +373,8 @@ struct Options { /// Post processes the problems void bin_problems() { + problem_bins.clear(); + problem_count = int(problem_sizes.size()); // @@ -340,19 +401,22 @@ struct Options { << " 'group' may compute a unique problem size. Problem sizes and pointers to matrices are both stored\n" << " in device Global Memory and loaded by the kernel.\n\n" << "Options:\n\n" - << " --help If specified, displays this usage statement.\n\n" - << " --benchmark= Executes a benchmark problem size.\n" - << " --output_file= Path to a CSV file to output results. If it exists already, results are appended.\n" - << " --tag= String tag to prepend to the CSV file.\n" - << " --groups= Number of individual GEMM problems (default: --groups=15)\n" - << " --m= Sets the M dimension for all groups. Otherwise, it is selected randomly\n" - << " --n= Sets the N dimension for all groups. Otherwise, it is selected randomly\n" - << " --k= Sets the K dimension for all groups. Otherwise, it is selected randomly\n" - << " --alpha= Epilogue scalar alpha (real part)\n" - << " --beta= Epilogue scalar beta (real part)\n\n" - << " --iterations= Number of profiling iterations to perform.\n" - << " --reference-check= If true, performs reference check.\n" - << " --verbose= If true, prints problem sizes and batching structure.\n"; + << " --help If specified, displays this usage statement.\n\n" + << " --benchmark= Executes a benchmark problem size.\n" + << " --output_file= Path to a CSV file to output results. If it exists already, results are appended.\n" + << " --tag= String tag to prepend to the CSV file.\n" + << " --groups= Number of individual GEMM problems (default: --groups=15)\n" + << " --m= Sets the M dimension for all groups. Otherwise, it is selected randomly\n" + << " --n= Sets the N dimension for all groups. Otherwise, it is selected randomly\n" + << " --k= Sets the K dimension for all groups. 
Otherwise, it is selected randomly\n" + << " --alpha= Epilogue scalar alpha (real part)\n" + << " --beta= Epilogue scalar beta (real part)\n" + << " --scheduler-modes= List of scheduler modes to be profile for grouped GEMM scheduler (default: --scheduler_modes=kDeviceOnly)\n" + << " --iterations= Number of profiling iterations to perform.\n" + << " --reference-check= If true, performs reference check.\n" + << " --verbose= If true, prints problem sizes and batching structure.\n" + << " --profile-initialization= If true, profiles the device-level kernel's initialization.\n" + << " --sort-problems= If true, sorts problem sizes in descending order of GEMM-K dimension.\n"; out << "\n\nExamples:\n\n" @@ -365,6 +429,12 @@ struct Options { << "# Runs a grouped GEMM that is equivalent to a batched GEMM\n" << "$ ./examples/24_gemm_grouped/24_gemm_grouped --groups=100 --m=2048 --n=1024 --k=1024 --verbose=true\n\n" + << "# Runs a grouped GEMM with each different scheduler mode\n" + << "$ ./examples/24_gemm_grouped/24_gemm_grouped --scheduler-modes=all\n\n" + + << "# Runs a grouped GEMM with each different scheduler mode and profiles host-side initialization time\n" + << "$ ./examples/24_gemm_grouped/24_gemm_grouped --scheduler-modes=all --profile-initialization=true\n\n" + << "# Runs a grouped GEMM problem given an externally supplied benchmark file. This is a text file in which\n" << "# Each line contains a unique group index and an MxNxK triple indicating problemsize.\n" << "#\n" @@ -385,13 +455,13 @@ struct Options { /// Compute performance in GFLOP/s double gflops(double runtime_s) const { - // Number of real-valued multiply-adds + // Number of real-valued multiply-adds int64_t fmas = int64_t(); for (auto const & problem : problem_sizes) { fmas += problem.product(); } - + // Two flops per multiply-add return 2.0 * double(fmas) / double(1.0e9) / runtime_s; } @@ -399,10 +469,9 @@ struct Options { /////////////////////////////////////////////////////////////////////////////////////////////////// -template -class TestbedGrouped { +template +class BaseTestbed { public: - // // Type definitions // @@ -421,8 +490,6 @@ public: using MatrixCoord = typename LayoutC::TensorCoord; -private: - // // Data members // @@ -462,13 +529,7 @@ private: cutlass::DeviceAllocation ptr_C; cutlass::DeviceAllocation ptr_D; -public: - - // - // Methods - // - - TestbedGrouped( + BaseTestbed( Options &options_, cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, @@ -481,13 +542,11 @@ public: return options.problem_count; } -private: - /// Helper to initialize a tensor view template - void initialize_tensor_( + void initialize_tensor( Element *ptr, - size_t capacity, + size_t capacity, cutlass::Distribution::Kind dist_kind, uint32_t seed) { @@ -519,7 +578,7 @@ private: cutlass::reference::device::BlockFillRandomUniform( ptr, capacity, seed, scope_max, scope_min, 0); - } + } else if (dist_kind == cutlass::Distribution::Gaussian) { cutlass::reference::device::BlockFillRandomGaussian( @@ -530,7 +589,7 @@ private: // Fill with increasing elements cutlass::reference::device::BlockFillSequential( ptr, capacity, Element(1), Element()); - } + } else { // Fill with all 1s @@ -539,65 +598,13 @@ private: } } - /// Verbose printing of problem sizes - void print_problem_sizes_() { - - // Print groups - std::cout << problem_count() << " groups:\n"; - - int32_t idx = 0; - int64_t total_tiles = 0; - - for (auto const & problem : options.problem_sizes) { - 
- int tiles = - ((problem.m() + Gemm::ThreadblockShape::kM - 1) / Gemm::ThreadblockShape::kM) * - ((problem.n() + Gemm::ThreadblockShape::kN - 1) / Gemm::ThreadblockShape::kN); - - total_tiles += tiles; - - std::cout << " [" << idx << "]: " - << problem.m() << "-by-" << problem.n() << "-by-" << problem.k() - << " (" << tiles << " threadblock tiles)" << "\n"; - - ++idx; - } - - // Print batched GEMM equivalent - size_t bin_idx = 0; - size_t problem_count_check = 0; - std::cout << "\nConventionally executed as " << options.problem_bins.size() << " batched GEMMs:\n"; - for (auto const & bin : options.problem_bins) { - - std::cout << " [" << bin_idx << "]: " - << bin.first.m() << "-by-" << bin.first.n() << "-by-" << bin.first.k() - << ", batch count: " << bin.second.size() << "\n"; - - ++bin_idx; - problem_count_check += bin.second.size(); - } - - if (problem_count_check != problem_count()) { - std::cout << "\n***\nERROR in BINNING LOGIC!\n***\n" << std::endl; - } - } - - /// Initializes data structures - void initialize_() { - - // - // Choose random problem sizes - // - - // construct a few problems of random sizes - srand(seed); - + /// Allocates device-side data + void allocate() { int64_t total_elements_A = 0; int64_t total_elements_B = 0; int64_t total_elements_C = 0; int64_t total_elements_D = 0; - lda_host.resize(problem_count()); ldb_host.resize(problem_count()); ldc_host.resize(problem_count()); @@ -628,14 +635,22 @@ private: total_elements_D += elements_D; } - problem_sizes_device.reset(problem_count()); - problem_sizes_device.copy_from_host(options.problem_sizes.data()); - lda.reset(problem_count()); ldb.reset(problem_count()); ldc.reset(problem_count()); ldd.reset(problem_count()); + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + } + + /// Initializes device-side data + void initialize() { + problem_sizes_device.reset(problem_count()); + problem_sizes_device.copy_from_host(options.problem_sizes.data()); + lda.copy_from_host(lda_host.data()); ldb.copy_from_host(ldb_host.data()); ldc.copy_from_host(ldc_host.data()); @@ -645,11 +660,6 @@ private: // Assign pointers // - block_A.reset(total_elements_A); - block_B.reset(total_elements_B); - block_C.reset(total_elements_C); - block_D.reset(total_elements_D); - std::vector ptr_A_host(problem_count()); std::vector ptr_B_host(problem_count()); std::vector ptr_C_host(problem_count()); @@ -664,13 +674,13 @@ private: ptr_A.reset(problem_count()); ptr_A.copy_from_host(ptr_A_host.data()); - + ptr_B.reset(problem_count()); ptr_B.copy_from_host(ptr_B_host.data()); - + ptr_C.reset(problem_count()); ptr_C.copy_from_host(ptr_C_host.data()); - + ptr_D.reset(problem_count()); ptr_D.copy_from_host(ptr_D_host.data()); @@ -678,16 +688,16 @@ private: // Initialize the problems of the workspace // - initialize_tensor_(block_A.get(), total_elements_A, init_A, seed * 2021); - initialize_tensor_(block_B.get(), total_elements_B, init_B, seed * 2022); - initialize_tensor_(block_C.get(), total_elements_C, init_C, seed * 2023); + initialize_tensor(block_A.get(), block_A.size(), init_A, seed * 2021); + initialize_tensor(block_B.get(), block_B.size(), init_B, seed * 2022); + initialize_tensor(block_C.get(), block_C.size(), init_C, seed * 2023); cutlass::reference::device::BlockFillSequential( - block_D.get(), total_elements_D, ElementC(), ElementC()); + block_D.get(), block_D.size(), ElementC(), ElementC()); } /// Verifies the result is a GEMM - bool verify_() { + bool 
verify() { bool passed = true; @@ -702,7 +712,7 @@ private: MatrixCoord extent_A{problem.m(), problem.k()}; MatrixCoord extent_B{problem.k(), problem.n()}; MatrixCoord extent_C{problem.m(), problem.n()}; - + cutlass::TensorView view_A(block_A.get() + offset_A.at(i), layout_A, extent_A); cutlass::TensorView view_B(block_B.get() + offset_B.at(i), layout_B, extent_B); cutlass::TensorView view_C(block_C.get() + offset_C.at(i), layout_C, extent_C); @@ -714,18 +724,18 @@ private: cutlass::reference::device::GemmComplex< ElementA, LayoutA, ElementB, LayoutB, - ElementC, LayoutC, + ElementC, LayoutC, ElementCompute, ElementAccumulator >( problem, - options.alpha, + options.alpha, view_A, Gemm::kTransformA, view_B, Gemm::kTransformB, - options.beta, - view_C, - view_Ref_device, + options.beta, + view_C, + view_Ref_device, ElementAccumulator(0) ); @@ -738,7 +748,7 @@ private: cutlass::TensorView view_D( matrix_D.data(), layout_D, extent_C); cutlass::TensorView view_Ref(matrix_Ref.data(), layout_D, extent_C); - + // Reference check passed = cutlass::reference::host::TensorEquals(view_D, view_Ref); @@ -751,227 +761,62 @@ private: return passed; } -public: +}; - /// Returns the number of threadblocks to launch if the kernel can run on the target - /// device. Otherwise, returns zero. - int sufficient() const { - cudaDeviceProp properties; - int device_idx; - cudaError_t result = cudaGetDevice(&device_idx); +template +class TestbedBatched : BaseTestbed { +public: + TestbedBatched( + Options &options_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3080 + ): BaseTestbed(options_, init_A_, init_B_, init_C_, seed_) {} - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDevice() API call failed."); - } + void print_problem_sizes() { + std::cout << std::endl; + size_t bin_idx = 0; + size_t problem_count_check = 0; + std::cout << "Conventionally executed as " << this->options.problem_bins.size() << " batched GEMMs:\n"; + for (auto const & bin : this->options.problem_bins) { - result = cudaGetDeviceProperties(&properties, device_idx); + std::cout << " [" << bin_idx << "]: " + << bin.first.m() << "-by-" << bin.first.n() << "-by-" << bin.first.k() + << ", batch count: " << bin.second.size() << "\n"; - if (result != cudaSuccess) { - throw std::runtime_error("cudaGetDeviceProperties() failed"); + ++bin_idx; + problem_count_check += bin.second.size(); } - int occupancy = Gemm::maximum_active_blocks(); - - return properties.multiProcessorCount * occupancy; + if (problem_count_check != size_t(this->problem_count())) { + std::cout << "\n***\nERROR in BINNING LOGIC!\n***\n" << std::endl; + } + std::cout << std::endl; } - - /// Executes a Grouped GEMM kernel and measures runtime. - Result profile_grouped() { + /// Executes a batched kernel and measures runtime + Result profile() { + std::cout << "Batched GEMM:\n" + << "====================================================" << std::endl; Result result; - - int threadblock_count = sufficient(); - - // Early exit - if (!threadblock_count) { - std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Grouped GEMM kernel." 
<< std::endl; - return result; - } - - if (options.verbose) { - print_problem_sizes_(); - } - result.passed = false; // Initialize the problem - initialize_(); - - // Configure the GEMM arguments - typename EpilogueOutputOp::Params epilogue_op(options.alpha, options.beta); + this->allocate(); + this->initialize(); - // Configure GEMM arguments - typename Gemm::Arguments args( - problem_sizes_device.get(), - problem_count(), - threadblock_count, - epilogue_op, - ptr_A.get(), - ptr_B.get(), - ptr_C.get(), - ptr_D.get(), - lda.get(), - ldb.get(), - ldc.get(), - ldd.get() - ); - - // Initialize the GEMM object - Gemm gemm; - - result.status = gemm.initialize(args); - - if (result.status != cutlass::Status::kSuccess) { - std::cerr << "Failed to initialize CUTLASS Grouped GEMM kernel." << std::endl; - return result; - } - - // Run the grouped GEMM object - result.status = gemm.run(); - - if (result.status != cutlass::Status::kSuccess) { - std::cerr << "Failed to run CUTLASS Grouped GEMM kernel." << std::endl; - return result; - } - - // Wait for completion - result.error = cudaDeviceSynchronize(); - - if (result.error != cudaSuccess) { - std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); - return result; - } - - // - // Verify correctness - // - result.passed = true; - - if (options.reference_check) { - result.passed = verify_(); - } - - // - // Warm-up run of the grouped GEMM object - // - result.status = gemm.run(); - - if (result.status != cutlass::Status::kSuccess) { - std::cerr << "Failed to run CUTLASS Grouped GEMM kernel." << std::endl; - return result; - } - - // - // Construct events - // - - cudaEvent_t events[2]; - - for (auto & event : events) { - result.error = cudaEventCreate(&event); - if (result.error != cudaSuccess) { - std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; - return -1; - } - } - - // Record an event at the start of a series of GEMM operations - result.error = cudaEventRecord(events[0]); - if (result.error != cudaSuccess) { - std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; - return result; - } - - // - // Run profiling loop - // - - for (int iter = 0; iter < options.iterations; ++iter) { - gemm(); - } - - // - // Stop profiling loop - // - - // Record an event when the GEMM operations have been launched. - result.error = cudaEventRecord(events[1]); - if (result.error != cudaSuccess) { - std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; - return result; - } - - // Wait for work on the device to complete. - result.error = cudaEventSynchronize(events[1]); - if (result.error != cudaSuccess) { - std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; - return result; - } - - // Measure elapsed runtime - float runtime_ms = 0; - result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); - if (result.error != cudaSuccess) { - std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; - return result; + if (this->options.verbose) { + print_problem_sizes(); } - // Compute average runtime and GFLOPs. 
- result.runtime_ms = double(runtime_ms) / double(options.iterations); - result.gflops = options.gflops(result.runtime_ms / 1000.0); - - // - // Cleanup - // - - for (auto event : events) { - (void)cudaEventDestroy(event); - } - - int32_t idx = 0; - int64_t total_tiles = 0; - - for (auto const & problem : options.problem_sizes) { - - int tiles = - ((problem.m() + Gemm::ThreadblockShape::kM - 1) / Gemm::ThreadblockShape::kM) * - ((problem.n() + Gemm::ThreadblockShape::kN - 1) / Gemm::ThreadblockShape::kN); - - total_tiles += tiles; - ++idx; - } - - std::cout << std::endl; - std::cout << "Grouped GEMM (CUTLASS):\n" - << "====================================================" << std::endl; - - std::cout << " " << total_tiles << " total threadblock tiles." << std::endl; - - std::cout << std::endl; - std::cout << " " << "Grouped Runtime: " << result.runtime_ms << " ms" << std::endl; - std::cout << " " << "Grouped GFLOPs: " << result.gflops << std::endl; - - if (options.output_file.good()) { - options.output_file << options.output_tag << ",CUTLASS,grouped," - << problem_count() << "," << result.runtime_ms << "," << result.gflops << std::endl; - } - - return result; - } - - /// Executes a conventional batched GEMM kernel. - Result profile_batched() { - - Result result; - result.passed = false; - // // Prepare batched GEMM environment // - int32_t effective_streams = (options.cuda_streams ? options.cuda_streams : 1); + int32_t effective_streams = (this->options.cuda_streams ? this->options.cuda_streams : 1); // Array of leading dimensions used by batched GEMM calls std::vector bin_problem_sizes; @@ -985,15 +830,15 @@ public: std::vector ptr_B_batched_host; std::vector ptr_C_batched_host; - for (auto const & bin : options.problem_bins) { + for (auto const & bin : this->options.problem_bins) { int first_idx = bin.second.front(); - - bin_problem_sizes.push_back(options.problem_sizes.at(first_idx)); + + bin_problem_sizes.push_back(this->options.problem_sizes.at(first_idx)); bin_count.push_back(int32_t(bin.second.size())); - bin_ldm_A.push_back(static_cast(lda_host.at(first_idx))); - bin_ldm_B.push_back(static_cast(ldb_host.at(first_idx))); - bin_ldm_C.push_back(static_cast(ldc_host.at(first_idx))); + bin_ldm_A.push_back(static_cast(this->lda_host.at(first_idx))); + bin_ldm_B.push_back(static_cast(this->ldb_host.at(first_idx))); + bin_ldm_C.push_back(static_cast(this->ldc_host.at(first_idx))); if (ptr_A_batched_host.size() % 2) { ptr_A_batched_host.push_back(nullptr); @@ -1005,29 +850,29 @@ public: for (int idx : bin.second) { - if (bin_problem_sizes.back() != options.problem_sizes.at(idx)) { + if (bin_problem_sizes.back() != this->options.problem_sizes.at(idx)) { std::cerr << "Error - failed to group problems.\n"; return result; } - if (bin_ldm_A.back() != lda_host.at(idx)) { + if (bin_ldm_A.back() != this->lda_host.at(idx)) { std::cerr << "Error - failed to group problems.\n"; return result; } - if (bin_ldm_B.back() != ldb_host.at(idx)) { + if (bin_ldm_B.back() != this->ldb_host.at(idx)) { std::cerr << "Error - failed to group problems.\n"; return result; } - if (bin_ldm_C.back() != ldc_host.at(idx)) { + if (bin_ldm_C.back() != this->ldc_host.at(idx)) { std::cerr << "Error - failed to group problems.\n"; return result; } - ptr_A_batched_host.push_back(block_A.get() + offset_A.at(idx)); - ptr_B_batched_host.push_back(block_B.get() + offset_B.at(idx)); - ptr_C_batched_host.push_back(block_D.get() + offset_C.at(idx)); + ptr_A_batched_host.push_back(this->block_A.get() + this->offset_A.at(idx)); + 
ptr_B_batched_host.push_back(this->block_B.get() + this->offset_B.at(idx)); + ptr_C_batched_host.push_back(this->block_D.get() + this->offset_C.at(idx)); } } @@ -1048,15 +893,14 @@ public: // Create CUDA streams to maximize concurrency of batched-array GEMM kernels // std::vector cuda_streams; - char const *provider = "CUTLASS"; // // Warmup run // - if (options.cuda_streams) { - for (int i = 0; i < options.cuda_streams; ++i) { + if (this->options.cuda_streams) { + for (int i = 0; i < this->options.cuda_streams; ++i) { cudaStream_t stream; result.error = cudaStreamCreate(&stream); @@ -1074,7 +918,7 @@ public: } // Use 'D' for the in/out workspace - block_D.copy_from_device(block_C.get()); + this->block_D.copy_from_device(this->block_C.get()); for (int bin_idx = 0; bin_idx < int32_t(bin_problem_sizes.size()); ++bin_idx) { @@ -1094,9 +938,9 @@ public: // // Configure the GEMM arguments - typename EpilogueOutputOp::Params epilogue_op(options.alpha, options.beta); + typename Gemm::EpilogueOutputOp::Params epilogue_op(this->options.alpha, this->options.beta); - typename GemmBatched::Arguments arguments{ + typename Gemm::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kArray, problem, batch_count, @@ -1115,7 +959,7 @@ public: int64_t(ldc) }; - GemmBatched gemm_op; + Gemm gemm_op; cutlass::Status status = gemm_op.initialize(arguments); @@ -1130,7 +974,7 @@ public: std::cerr << "CUTLASS error on line " << __LINE__ << std::endl; return result; } - + } // @@ -1182,8 +1026,8 @@ public: int last_stream_idx = 0; - for (int iter = 0; iter < options.iterations; ++iter) { - + for (int iter = 0; iter < this->options.iterations; ++iter) { + for (int bin_idx = 0; bin_idx < int32_t(bin_problem_sizes.size()); ++bin_idx) { cutlass::gemm::GemmCoord const & problem = bin_problem_sizes[bin_idx]; @@ -1204,9 +1048,9 @@ public: // // Configure the GEMM arguments - typename EpilogueOutputOp::Params epilogue_op(options.alpha, options.beta); + typename Gemm::EpilogueOutputOp::Params epilogue_op(this->options.alpha, this->options.beta); - typename GemmBatched::Arguments arguments{ + typename Gemm::Arguments arguments{ cutlass::gemm::GemmUniversalMode::kArray, problem, batch_count, @@ -1225,7 +1069,7 @@ public: int64_t(ldc) }; - GemmBatched gemm_op; + Gemm gemm_op; cutlass::Status status = gemm_op.initialize(arguments); @@ -1254,7 +1098,7 @@ public: std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; return result; } - + // // Wait for work to be completed // @@ -1266,15 +1110,263 @@ public: return result; } - // Wait for work on the device to complete. - result.error = cudaEventSynchronize(events[1]); + // Measure elapsed runtime + float runtime_ms = 0; + result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); if (result.error != cudaSuccess) { - std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Compute average runtime and GFLOPs. 
+ result.runtime_ms = double(runtime_ms) / double(this->options.iterations); + result.gflops = this->options.gflops(result.runtime_ms / 1000.0); + + // + // Cleanup + // + + for (auto event : events) { + (void)cudaEventDestroy(event); + } + + for (auto stream : cuda_streams) { + if (stream) { + (void)cudaStreamDestroy(stream); + } + } + + std::cout << " " << this->options.problem_bins.size() << " batched GEMMs launched" << std::endl; + std::cout << std::endl; + std::cout << " " << "Batched Runtime: " << result.runtime_ms << " ms" << std::endl; + std::cout << " " << "Batched GFLOPs: " << result.gflops << std::endl; + + std::string provider = "CUTLASS"; + + if (this->options.output_file.good()) { + this->options.output_file << this->options.output_tag << "," << provider << ",batched," + << this->options.problem_count << "," << result.runtime_ms << "," << result.gflops << std::endl; + } + + result.passed = true; + return result; + } +}; + +template +class TestbedGrouped : BaseTestbed { +public: + TestbedGrouped( + Options &options_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3080 + ): BaseTestbed(options_, init_A_, init_B_, init_C_, seed_) {} + + // Redefine GEMM with different GroupScheduleMode_ + using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< + typename Gemm_::ElementA, + typename Gemm_::LayoutA, + Gemm_::kTransformA, + Gemm_::kAlignmentA, + typename Gemm_::ElementB, + typename Gemm_::LayoutB, + Gemm_::kTransformB, + Gemm_::kAlignmentB, + typename Gemm_::ElementC, + typename Gemm_::LayoutC, + typename Gemm_::ElementAccumulator, + typename Gemm_::OperatorClass, + typename Gemm_::ArchTag, + typename Gemm_::ThreadblockShape, + typename Gemm_::WarpShape, + typename Gemm_::InstructionShape, + typename Gemm_::EpilogueOutputOp, + typename Gemm_::ThreadblockSwizzle, + Gemm_::kStages, + GroupScheduleMode_>::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmGrouped; + + /// Verbose printing of problem sizes + void print_problem_sizes() { + std::cout << std::endl; + + // Print groups + std::cout << this->problem_count() << " groups:\n"; + + int32_t idx = 0; + int64_t total_tiles = 0; + + for (auto const & problem : this->options.problem_sizes) { + int tiles = Gemm::problem_tile_count(problem); + total_tiles += tiles; + + std::cout << " [" << idx << "]: " + << problem.m() << "-by-" << problem.n() << "-by-" << problem.k() + << " (" << tiles << " threadblock tiles)" << "\n"; + + ++idx; + } + std::cout << std::endl; + } + + /// Sort problems in descending order of problem-K dimension + void sort_problems() { + Gemm::sort_problems(this->options.problem_count, + this->options.problem_sizes.data(), + this->lda_host.data(), + this->ldb_host.data(), + this->ldc_host.data(), + this->ldd_host.data(), + this->offset_A.data(), + this->offset_B.data(), + this->offset_C.data(), + this->offset_D.data()); + } + + /// Executes a grouped kernel and measures runtime + Result profile() { + std::string sched_mode = this->options.scheduler_mode_to_str.find(GroupScheduleMode_)->second; + + std::cout << std::endl; + std::cout << "Grouped GEMM (CUTLASS) with mode " << sched_mode << ":\n" + << "====================================================" << std::endl; + + Result result; + + int threadblock_count = Gemm::sufficient(this->options.problem_sizes.data(), this->options.problem_count); + + // Early exit + 
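+ // As with the earlier per-testbed sufficient() helper, Gemm::sufficient() is expected to return the number of threadblocks to launch when the kernel can run on the active device, and zero otherwise.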
if (!threadblock_count) { + std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Grouped GEMM kernel." << std::endl; + return result; + } + + result.passed = false; + + // Initialize the problem + this->allocate(); + if (this->options.sort_problems) { + sort_problems(); + } + this->initialize(); + + if (this->options.verbose) { + print_problem_sizes(); + } + + // Configure the GEMM arguments + typename Gemm::EpilogueOutputOp::Params epilogue_op(this->options.alpha, this->options.beta); + + // Configure GEMM arguments + typename Gemm::Arguments args( + this->problem_sizes_device.get(), + this->problem_count(), + threadblock_count, + epilogue_op, + this->ptr_A.get(), + this->ptr_B.get(), + this->ptr_C.get(), + this->ptr_D.get(), + this->lda.get(), + this->ldb.get(), + this->ldc.get(), + this->ldd.get(), + this->options.problem_sizes.data() + ); + + // Initialize the GEMM object + Gemm gemm; + + size_t workspace_size = gemm.get_workspace_size(args); + cutlass::DeviceAllocation workspace(workspace_size); + + result.status = gemm.initialize(args, workspace.get()); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize CUTLASS Grouped GEMM kernel." << std::endl; + return result; + } + + // Run the grouped GEMM object + result.status = gemm.run(); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run CUTLASS Grouped GEMM kernel." << std::endl; + return result; + } + + // Wait for completion + result.error = cudaDeviceSynchronize(); + + if (result.error != cudaSuccess) { + std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); + return result; + } + + // + // Verify correctness + // + result.passed = true; + + if (this->options.reference_check) { + result.passed = this->verify(); + } + + // + // Warm-up run of the grouped GEMM object + // + result.status = gemm.run(); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run CUTLASS Grouped GEMM kernel." << std::endl; + return result; + } + + // + // Construct events + // + + cudaEvent_t events[2]; + + for (auto & event : events) { + result.error = cudaEventCreate(&event); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; + return -1; + } + } + + // Record an event at the start of a series of GEMM operations + result.error = cudaEventRecord(events[0]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // + // Run profiling loop + // + + for (int iter = 0; iter < this->options.iterations; ++iter) { + gemm(); + } + + // + // Stop profiling loop + // + + // Record an event when the GEMM operations have been launched. + result.error = cudaEventRecord(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; return result; } // Wait for work on the device to complete. - result.error = cudaEventSynchronize(events[0]); + result.error = cudaEventSynchronize(events[1]); if (result.error != cudaSuccess) { std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; return result; @@ -1289,8 +1381,8 @@ public: } // Compute average runtime and GFLOPs. 
- result.runtime_ms = double(runtime_ms) / double(options.iterations); - result.gflops = options.gflops(result.runtime_ms / 1000.0); + result.runtime_ms = double(runtime_ms) / double(this->options.iterations); + result.gflops = this->options.gflops(result.runtime_ms / 1000.0); // // Cleanup @@ -1299,28 +1391,40 @@ public: for (auto event : events) { (void)cudaEventDestroy(event); } - - for (auto stream : cuda_streams) { - if (stream) { - (void)cudaStreamDestroy(stream); + + // Optionally profile initialization + if (this->options.profile_initialization) { + // Warm up + gemm.initialize(args, workspace.get()); + + auto start_time = std::chrono::high_resolution_clock::now(); + for (int32_t i = 0; i < this->options.iterations; ++i) { + gemm.initialize(args, workspace.get()); } + auto end_time = std::chrono::high_resolution_clock::now(); + + std::chrono::duration duration = end_time - start_time; + duration /= double(this->options.iterations); + result.initialization_time_ms = duration.count(); } - std::cout << std::endl; - std::cout << "Batched GEMM:\n" - << "====================================================" << std::endl; + int64_t total_tiles = Gemm::group_tile_count(args); + std::cout << " " << total_tiles << " total threadblock tiles." << std::endl; - std::cout << " " << bin_problem_sizes.size() << " batched GEMMs launched" << std::endl; std::cout << std::endl; - std::cout << " " << "Batched Runtime: " << result.runtime_ms << " ms" << std::endl; - std::cout << " " << "Batched GFLOPs: " << result.gflops << std::endl; + std::cout << " " << "Grouped Runtime: " << result.runtime_ms << " ms" << std::endl; + std::cout << " " << "Grouped GFLOPs: " << result.gflops << std::endl; + if (this->options.profile_initialization) { + std::cout << " " << "Init Runtime: " << result.initialization_time_ms << " ms" << std::endl; + } - if (options.output_file.good()) { - options.output_file << options.output_tag << "," << provider << ",batched," - << problem_count() << "," << result.runtime_ms << "," << result.gflops << std::endl; + if (this->options.output_file.good()) { + this->options.output_file << this->options.output_tag << ",CUTLASS,grouped-" << sched_mode << "," + << this->options.problem_count << "," << result.runtime_ms << "," << result.gflops << std::endl; } - result.passed = true; + std::cout << "\nPassed\n"; + return result; } }; @@ -1329,10 +1433,6 @@ public: int main(int argc, char const **args) { - // - // This example uses mma.sync to directly access Tensor Cores to achieve peak performance. - // - cudaDeviceProp props; cudaError_t error = cudaGetDeviceProperties(&props, 0); @@ -1342,12 +1442,12 @@ int main(int argc, char const **args) { } if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) { - + // // This example requires an NVIDIA Ampere-architecture GPU. 
// - std::cout + std::cout << "CUTLASS's Grouped GEMM example requires a GPU of NVIDIA's Ampere Architecture or " << "later (compute capability 80 or greater).\n"; @@ -1359,7 +1459,7 @@ int main(int argc, char const **args) { // Options options; - + options.parse(argc, args); if (options.help) { @@ -1373,9 +1473,11 @@ int main(int argc, char const **args) { } // - // Define the Grouped GEMM type + // Define the Grouped and Batched GEMM types // + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; using ElementOutput = cutlass::half_t; using ElementAccumulator = float; @@ -1383,18 +1485,42 @@ int main(int argc, char const **args) { using LayoutB = cutlass::layout::ColumnMajor; using LayoutC = cutlass::layout::ColumnMajor; + // Gemm operator cutlass_tensorop_f16_s16816gemm_f16_128x128_32x4_nt_align8 + using GemmBatched = cutlass::gemm::device::GemmUniversal< + ElementA, LayoutA, + ElementB, LayoutB, + ElementOutput, LayoutC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 4 + >; + + // Define a grouped GEMM kernel with all template parameters set except + // for scheduling mode. This will be used as the template for all scheduling + // modes executed. using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGrouped< - cutlass::half_t, + ElementA, LayoutA, cutlass::ComplexTransform::kNone, 8, - cutlass::half_t, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, 8, ElementOutput, LayoutC, - ElementAccumulator, - cutlass::arch::OpClassTensorOp, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, cutlass::arch::Sm80, cutlass::gemm::GemmShape<128, 128, 32>, cutlass::gemm::GemmShape<64, 64, 32>, @@ -1402,64 +1528,50 @@ int main(int argc, char const **args) { cutlass::epilogue::thread::LinearCombination< ElementOutput, 128 / cutlass::sizeof_bits::value, ElementAccumulator, ElementAccumulator>, - cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + // NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels. + // This parameter is passed in at present to match the APIs of other kernels. The parameter + // is unused within the kernel. 
+ cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, 4>::GemmKernel; using GemmGrouped = cutlass::gemm::device::GemmGrouped; - // - // Define a conventional batched GEMM type - // - - // Gemm operator cutlass_tensorop_f16_s16816gemm_f16_128x128_32x4_nt_align8 - using GemmBatched = cutlass::gemm::device::GemmUniversal< - cutlass::half_t, LayoutA, - cutlass::half_t, LayoutB, - ElementOutput, LayoutC, - ElementAccumulator, - cutlass::arch::OpClassTensorOp, - cutlass::arch::Sm80, - cutlass::gemm::GemmShape<128, 128, 32>, - cutlass::gemm::GemmShape<64, 64, 32>, - cutlass::gemm::GemmShape<16, 8, 16>, - cutlass::epilogue::thread::LinearCombination< - ElementOutput, - 128 / cutlass::sizeof_bits::value, - ElementAccumulator, - ElementAccumulator - >, - cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, - 4 - >; - // // Profile it // - TestbedGrouped testbed(options); - - if (!testbed.sufficient()) { - std::cout << "The active CUDA device lacks sufficient hardware resources to execute this kernel.\n"; - return 0; + TestbedBatched testbed_batched(options); + Result result = testbed_batched.profile(); + if (result.error) { + return 1; } - Result result = testbed.profile_grouped(); - if (!result.passed) { - std::cout << "Profiling CUTLASS grouped GEMM has failed.\n"; - std::cout << "\nFailed\n"; - return -1; - } + using GroupScheduleMode = cutlass::gemm::kernel::GroupScheduleMode; + for (GroupScheduleMode mode : options.scheduler_modes) { + Result result; + switch (mode) { + case GroupScheduleMode::kDeviceOnly: + { + TestbedGrouped runner(options); + result = runner.profile(); + break; + } + case GroupScheduleMode::kHostPrecompute: + { + TestbedGrouped runner(options); + result = runner.profile(); + break; + } + } - result = testbed.profile_batched(); - if (!result.passed) { + if (result.error != cudaSuccess) { + return 1; + } - std::cout << "Profiling batched GEMM has failed.\n"; - std::cout << "\nFailed\n"; - return -1; + // Override verbose flag to avoid printing duplicate information for each scheduling mode + options.verbose = false; } - std::cout << "\nPassed\n"; - return 0; } diff --git a/examples/25_ampere_fprop_mainloop_fusion/CMakeLists.txt b/examples/25_ampere_fprop_mainloop_fusion/CMakeLists.txt index 0bf0c775e9..ce9a0bd0c3 100644 --- a/examples/25_ampere_fprop_mainloop_fusion/CMakeLists.txt +++ b/examples/25_ampere_fprop_mainloop_fusion/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -34,3 +34,8 @@ cutlass_example_add_executable( ampere_fprop_mainloop_fusion.cu ) +cutlass_example_add_executable( + 25_ampere_3d_fprop_mainloop_fusion + ampere_3d_fprop_mainloop_fusion.cu + ) + diff --git a/examples/25_ampere_fprop_mainloop_fusion/ampere_3d_fprop_mainloop_fusion.cu b/examples/25_ampere_fprop_mainloop_fusion/ampere_3d_fprop_mainloop_fusion.cu new file mode 100644 index 0000000000..a1ca2b078f --- /dev/null +++ b/examples/25_ampere_fprop_mainloop_fusion/ampere_3d_fprop_mainloop_fusion.cu @@ -0,0 +1,776 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** + +This example shows how to fuse per-channel scale+bias+relu of the activations +into the 3D fprop mainloop. + +Compared with the original 3D fprop kernel, this example has two more vectors, one for +the scale and one for the bias. The length of the vectors equals the number of +activation channels. This kernel loads the vectors when the associated +activation channels are loaded in the mainloop. Between reading the +activations and scale/bias data from the shared memory and calling tensor core +instructions, scale+bias+relu is computed in the register file. + +This example is customized for the Ampere 16816 FP16 tensor core instruction. +Changing to different data types or a different tensor core instruction requires +source code changes. See +include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h for more +technical details. + +This example is adapted from 25_ampere_fprop_mainloop_fusion. The command +line is the same.
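+
+As a quick illustration (a sketch only; x, scale[c], and bias[c] are informal names for an
+activation element and the per-channel vectors described above, not identifiers from this
+file), the fused transform applied to each activation element of channel c is
+
+    x_fused = max(0, x * scale[c] + bias[c])
+
+which is also what the host-side reference check in this example computes.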
+*/ + +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/conv/kernel/default_conv3d_fprop_fusion.h" +#include "cutlass/conv/device/implicit_gemm_convolution_fusion.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/convolution.h" +#include "cutlass/util/tensor_view_io.h" + +#include "helper.h" + +// The code section below describes the data types for the input and output tensors and the computation between +// elements +using ElementAccumulator = float; // Data type of accumulator +using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta) +using ElementInputA = cutlass::half_t; // Data type of elements in input tensor +using ElementInputB = cutlass::half_t; // Data type of elements in input tensor +using ElementInputScaleBias = cutlass::half_t; // Data type of elements in input scale and bias vectors +using ElementOutput = float; // Data type of elements in output tensor + +using LayoutInputA = cutlass::layout::TensorNDHWC; +using LayoutInputB = cutlass::layout::TensorNDHWC; +using LayoutInputScaleBias = cutlass::layout::RowMajor; +using LayoutOutput = cutlass::layout::TensorNDHWC; + +// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM +using MMAOp = cutlass::arch::OpClassTensorOp; + +// This code section describes the CUDA SM architecture number +using SmArch = cutlass::arch::Sm80; + +// This code section describes the tile size a thread block will compute +using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock tile shape + +// This code section describes the tile size a warp will compute +using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp tile shape + +// This code section describes the size of the MMA op +using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape + +// This code section describes how threadblocks are scheduled on GPU +using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + +// Number of pipelines you want to use +constexpr int NumStages = 4; + +// This code section describes whether the selected iterator algorithm is Analytic or Optimized +static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = cutlass::conv::IteratorAlgorithm::kOptimized; + +// This code section describes the epilogue part of the kernel; we use the default value +using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // Data type of output matrix. + 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized + // memory access. This becomes the vector width of + // math instructions in the epilogue too.
+ ElementAccumulator, // Data type of accumulator + ElementComputeEpilogue>; // Data type for alpha/beta in linear combination + +using Conv3dFpropFusionKernel = typename cutlass::conv::kernel::DefaultConv3dFpropFusion< + ElementInputA, LayoutInputA, + ElementInputB, LayoutInputB, + ElementInputScaleBias, LayoutInputScaleBias, + ElementOutput, LayoutOutput, + ElementAccumulator, + MMAOp, + SmArch, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOp, + SwizzleThreadBlock, + NumStages, + cutlass::arch::OpMultiplyAdd, + IteratorAlgorithm +>::Kernel; + +using ImplicitGemmFusion = cutlass::conv::device::ImplicitGemmConvolutionFusion; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + cutlass::Tensor5DCoord input_size; + cutlass::Tensor5DCoord filter_size; + cutlass::Coord<3> padding; + cutlass::Coord<3> conv_stride; + cutlass::Coord<3> dilation; + bool reference_check; + bool measure_performance; + int iterations; + bool save_workspace; + ElementComputeEpilogue alpha; + ElementComputeEpilogue beta; + bool benchmark; + std::string tag; + + Options(): + help(false), + input_size(1, 32, 32, 32, 32), + filter_size(32, 3, 3, 3, 32), + padding(cutlass::make_Coord(1, 1, 1)), + conv_stride(cutlass::make_Coord(1, 1, 1)), + dilation(cutlass::make_Coord(1, 1, 1)), + reference_check(true), + measure_performance(false), + iterations(20), + save_workspace(false), + alpha(1), + beta(0), + benchmark(false) { } + + // Verify the problem size is compatible with the CUTLASS Convolution implementation. + bool valid() { + + // + // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, + // all pointers, strides, and tensor extents must be divisible by 8 elements. 
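+    // (128 bits per vector access) / (16 bits per cutlass::half_t element) = 8 elements per access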
+ // + int const kAlignment = 8; + + if ((input_size.c() % kAlignment) || + (filter_size.n() % kAlignment)) { + + // misaligned tensors + return false; + } + + // Invalid padding + if ((padding[0] != filter_size.d() / 2) || + (padding[1] != filter_size.h() / 2) || + (padding[2] != filter_size.w() / 2)) { + + return false; + } + + return true; + } + + /// Updates input and filter sizes + void update( + cutlass::Tensor5DCoord input_size, + cutlass::Tensor5DCoord filter_size, + cutlass::Coord<3> stride) { + + this->input_size = input_size; + this->filter_size = filter_size; + conv_stride = stride; + + padding[0] = filter_size.d() / 2; + padding[1] = filter_size.h() / 2; + padding[2] = filter_size.w() / 2; + } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + if (cmd.check_cmd_line_flag("ref-check")) { + reference_check = true; + } + + if (cmd.check_cmd_line_flag("perf-check")) { + measure_performance = true; + } + + if (cmd.check_cmd_line_flag("save-workspace")) { + save_workspace = true; + } + + if (cmd.check_cmd_line_flag("benchmark")) { + benchmark = true; + } + + cmd.get_cmd_line_argument("n", input_size.n()); + cmd.get_cmd_line_argument("d", input_size.d()); + cmd.get_cmd_line_argument("h", input_size.h()); + cmd.get_cmd_line_argument("w", input_size.w()); + cmd.get_cmd_line_argument("c", input_size.c()); + + cmd.get_cmd_line_argument("k", filter_size.n()); + cmd.get_cmd_line_argument("t", filter_size.d()); + cmd.get_cmd_line_argument("r", filter_size.h()); + cmd.get_cmd_line_argument("s", filter_size.w()); + filter_size.c() = input_size.c(); + + cmd.get_cmd_line_argument("alpha", alpha); + cmd.get_cmd_line_argument("beta", beta); + + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("tag", tag); + + if (filter_size.d() == 3 && filter_size.h() == 3 && filter_size.w() == 3) { + padding = cutlass::make_Coord(1, 1, 1); + } + else { + filter_size.d() = 1; + filter_size.h() = 1; + filter_size.w() = 1; + padding = cutlass::make_Coord(0, 0, 0); + } + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "25_ampere_3d_fprop_mainloop_fusion example\n\n" + << " This example fuses scale+bias+relu of the activations into Ampere's\n" + << " Tensor Core operators on F16 data types to compute\n" + << " forward convolution on tensors of layout NDHWC.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --n Input tensor extent N\n" + << " --d Input tensor extent D\n" + << " --h Input tensor extent H\n" + << " --w Input tensor extent W\n" + << " --c Input tensor extent C\n" + << " --k Filter extent K\n" + << " --t Filter extent T\n" + << " --r Filter extent R\n" + << " --s Filter extent S\n\n" + << " --alpha Epilogue scalar alpha\n" + << " --beta Epilogue scalar beta\n\n" + << " --ref-check If set (true), reference check on the host is computed\n" + << " --perf-check If set (true), performance is measured.\n" + << " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n" + << " --iterations Number of profiling iterations to perform.\n" + << " --save-workspace If set, workspace is written to a text file.\n" + << " --tag String to replicate across the first column in the results table\n"; + + out << "\n\nExamples:\n\n" + << "$ ./25_ampere_3d_fprop_mainloop_fusion --n=32 --d=96 --h=96 --w=96 --c=64 --k=64 --t=1 --r=1 --s=1\n\n" + << "$ ./25_ampere_3d_fprop_mainloop_fusion --n=1 --d=224 --h=224 --w=224 --c=32 --k=32 --t=3 --r=3 --s=3 --ref-check\n\n" + << "$ ./25_ampere_3d_fprop_mainloop_fusion --n=19 --d=94 --h=96 --w=96 --c=128 --k=128 --t=1 --r=1 --s=1\n\n"; + + return out; + } + + /// Computes the output tensor size (NPQK) + cutlass::Tensor5DCoord output_size() const { + return cutlass::Tensor5DCoord( + input_size.n(), + (input_size.d() + padding[0] + padding[0] - filter_size.d()) / conv_stride[0] + 1, + (input_size.h() + padding[1] + padding[1] - filter_size.h()) / conv_stride[1] + 1, + (input_size.w() + padding[2] + padding[2] - filter_size.w()) / conv_stride[2] + 1, + filter_size.n()); + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + + // Number of multiply-adds = NPQK * CRS + int64_t fmas = output_size().product() * int64_t(filter_size.d() * filter_size.h() * filter_size.w() * filter_size.c()); + + // Two flops per multiply-add + return 2.0 * double(fmas) / double(1.0e9) / runtime_s; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct Result { + double runtime_ms; + double gflops; + cutlass::Status status; + cutlass::Status reference_check; + cudaError_t error; + + Result(): + runtime_ms(0), + gflops(0), + status(cutlass::Status::kSuccess), + reference_check(cutlass::Status::kInvalid), + error(cudaSuccess) { } + + static std::ostream & print_header(std::ostream &out, Options const &options) { + + if (!options.tag.empty()) { + out << "Name,"; + } + + out << "Layer,N,D,H,W,C,K,T,R,S,Stride_D,Stride_H,Stride_W,Runtime,GFLOPs"; + + return out; + } + + std::ostream & print(std::ostream &out, int idx, Options const &options) { + + if (!options.tag.empty()) { + out << options.tag << ","; + } + + out + << "conv_" << idx << "," + << options.input_size.n() << "," + << options.input_size.d() << "," + << options.input_size.h() << "," + << options.input_size.w() << "," + << options.input_size.c() << "," + << options.filter_size.n() << "," + << options.filter_size.d() << "," + << options.filter_size.h() << "," + << options.filter_size.w() << "," + << 
options.conv_stride[0] << "," + << options.conv_stride[1] << "," + << options.conv_stride[2] << "," + << runtime_ms << "," + << gflops; + + return out; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Runs one benchmark +Result profile_convolution(Options const &options) { + + Result result; + + // + // Allocate host-device tensors using the CUTLASS Utilities. + // + + cutlass::HostTensor tensor_a(options.input_size); + cutlass::HostTensor tensor_transformed_a(options.input_size); + cutlass::HostTensor tensor_b(options.filter_size); + cutlass::HostTensor + tensor_a_scale({1, options.input_size.c()}); + cutlass::HostTensor + tensor_a_bias({1, options.input_size.c()}); + cutlass::HostTensor tensor_c(options.output_size()); + cutlass::HostTensor tensor_d(options.output_size()); + cutlass::HostTensor tensor_ref_d(options.output_size()); + + // + // Initialize tensors + // + + // Fill tensor A on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_a.host_view(), + 1, + ElementInputA(3), + ElementInputA(-4), + 0); + + // Fill scale vector for tensor A on host with uniform-distribution random + // data + cutlass::reference::host::TensorFillRandomUniform( + tensor_a_scale.host_view(), + 1, + ElementInputA(3), + ElementInputA(-4), + 0); + + // Fill bias vector for tensor A on host with uniform-distribution random + // data + cutlass::reference::host::TensorFillRandomUniform( + tensor_a_bias.host_view(), + 1, + ElementInputA(3), + ElementInputA(-4), + 0); + + // Fill tensor B on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_b.host_view(), + 1, + ElementInputB(7), + ElementInputB(-8), + 0); + + // Fill tensor C on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_c.host_view(), + 1, + ElementOutput(7), + ElementOutput(-8), + 0); + + // Fill tensor D for reference on host with zeros + cutlass::reference::host::TensorFill( + tensor_ref_d.host_view()); + + // Copy data from host to GPU + tensor_a.sync_device(); + tensor_a_scale.sync_device(); + tensor_a_bias.sync_device(); + tensor_b.sync_device(); + tensor_c.sync_device(); + tensor_d.sync_device(); + tensor_ref_d.sync_device(); + + // + // Define arguments for CUTLASS Convolution + // + + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; + + // Split K dimension into 1 partitions + int split_k_slices = 1; + + // Construct Conv3dProblemSize with user defined output size + cutlass::conv::Conv3dProblemSize problem_size( + options.input_size, + options.filter_size, + options.padding, + options.conv_stride, + options.dilation, + options.output_size(), + mode, + split_k_slices + ); + + typename ImplicitGemmFusion::Arguments arguments{ + problem_size, + tensor_a.device_ref(), + tensor_b.device_ref(), + tensor_a_scale.device_ref(), + tensor_a_bias.device_ref(), + tensor_c.device_ref(), + tensor_d.device_ref(), + {options.alpha, options.beta}, + }; + + // + // Initialize CUTLASS Convolution + // + + ImplicitGemmFusion implicit_gemm_fusion_op; + + size_t workspace_size = implicit_gemm_fusion_op.get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + result.status = implicit_gemm_fusion_op.can_implement(arguments); + CUTLASS_CHECK(result.status); + + result.status = implicit_gemm_fusion_op.initialize(arguments, workspace.get()); + 
CUTLASS_CHECK(result.status); + + // + // Launch initialized CUTLASS kernel + // + result.status = implicit_gemm_fusion_op(); + + CUTLASS_CHECK(result.status); + + // + // Optional reference check + // + + if (options.reference_check) { + std::cout << "Verification on device...\n"; + + // Compute scale + bias + relu in host code + for (int n = 0; n < options.input_size.n(); ++n) { + for (int d = 0; d < options.input_size.d(); ++d) { + for (int h = 0; h < options.input_size.h(); ++h) { + for (int w = 0; w < options.input_size.w(); ++w) { + for (int c = 0; c < options.input_size.c(); ++c) { + tensor_transformed_a.at({n, d, h, w, c}) = std::max( + ElementOutput(0), ElementOutput(tensor_a.at({n, d, h, w, c}) * + tensor_a_scale.at({0, c}) + + tensor_a_bias.at({0, c}))); + } + } + } + } + } + + tensor_transformed_a.sync_device(); + + // Compute with reference implementation + cutlass::reference::device::Conv3dFprop< + ElementInputA, + LayoutInputA, + ElementInputB, + LayoutInputB, + ElementOutput, + LayoutOutput, + ElementComputeEpilogue, + ElementAccumulator, + cutlass::NumericConverter + >( + problem_size, + tensor_transformed_a.device_ref(), + tensor_b.device_ref(), + tensor_c.device_ref(), + tensor_ref_d.device_ref(), + options.alpha, + options.beta + ); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + tensor_d.sync_host(); + tensor_ref_d.sync_host(); + + bool passed = cutlass::reference::host::TensorEquals( + tensor_d.host_view(), + tensor_ref_d.host_view()); + + if (!passed) { + result.reference_check = cutlass::Status::kErrorInternal; + std::cout << "ERROR - results miscompared.\n"; + } + else { + result.reference_check = cutlass::Status::kSuccess; + std::cout << "Passed.\n"; + } + } + else { + result.reference_check = cutlass::Status::kInvalid; + } + + if (options.save_workspace) { + + std::stringstream ss; + + ss << "25_ampere_3d_fprop_mainloop_fusion" + << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c() + << "_" + << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c() + << ".dat"; + + std::ofstream output_workspace(ss.str()); + + output_workspace + << "Input = \n" << tensor_a.host_view() << "\n\n" + << "Filters = \n" << tensor_b.host_view() << "\n\n"; + + if (options.reference_check) { + output_workspace << "Reference = \n" << tensor_ref_d.host_view() << "\n\n"; + } + + output_workspace << "Computed = \n" << tensor_d.host_view() << std::endl; + + std::cout << "Results written to '" << ss.str() << "'." << std::endl; + } + + // + // Performance measurement + // + + if (options.measure_performance) { + + cudaEvent_t events[2]; + + for (auto & event : events) { + result.error = cudaEventCreate(&event); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + } + + // Record an event at the start of a series of convolution operations. + result.error = cudaEventRecord(events[0]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Launch a sequence of implicit GEMM operations on the device + for (int iteration = 0; iteration < options.iterations; ++iteration) { + result.status = implicit_gemm_fusion_op(); + CUTLASS_CHECK(result.status); + } + + // Record an event when the convolutions have been launched. 
+ result.error = cudaEventRecord(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Wait for work on the device to complete. + result.error = cudaEventSynchronize(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Print average runtime and GFLOPs. + result.runtime_ms = double(runtime_ms) / double(options.iterations); + result.gflops = options.gflops(result.runtime_ms / 1000.0); + + // Cleanup + for (auto event : events) { + (void)cudaEventDestroy(event); + } + } + + return result; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + bool notSupported = false; + + // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. + // + // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv3dFprop examples. + if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) { + std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; + notSupported = true; + } + + cudaDeviceProp props; + CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); + + if (!(props.major >= 8)) { + std::cerr << "This test must run on SM80 or above.\n"; + notSupported = true; + } + + if (notSupported) { + return 0; + } + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.benchmark) { + // Benchmark several layers + + int batch_sizes[] = {34, 18}; + + struct Benchmark { + int d, h, w, c, k, t, r, s, stride_d, stride_h, stride_w; + } layers[] = { + {56, 56, 56, 64, 256, 1, 1, 1, 1, 1, 1}, + {56, 56, 56, 64, 64, 1, 1, 1, 1, 1, 1}, + {56, 56, 56, 64, 64, 3, 3, 3, 1, 1, 1}, + {56, 56, 56, 256, 64, 1, 1, 1, 1, 1, 1}, + {56, 56, 56, 256, 512, 1, 1, 1, 2, 2, 2}, + {56, 56, 56, 256, 128, 1, 1, 1, 1, 1, 1}, + {56, 56, 56, 128, 128, 3, 3, 3, 2, 2, 2}, + {28, 28, 28, 128, 512, 1, 1, 1, 1, 1, 1}, + {28, 28, 28, 512, 128, 1, 1, 1, 1, 1, 1}, + {28, 28, 28, 128, 128, 3, 3, 3, 1, 1, 1}, + {28, 28, 28, 512, 1024, 1, 1, 1, 2, 2, 2}, + {28, 28, 28, 512, 256, 1, 1, 1, 1, 1, 1}, + {28, 28, 28, 256, 256, 3, 3, 3, 2, 2, 2}, + {14, 14, 14, 256, 1024, 1, 1, 1, 1, 1, 1}, + {14, 14, 14, 1024, 256, 1, 1, 1, 1, 1, 1}, + {14, 14, 14, 256, 256, 3, 3, 3, 1, 1, 1}, + {14, 14, 14, 1024, 2048, 1, 1, 1, 2, 2, 2}, + {14, 14, 14, 1024, 512, 1, 1, 1, 1, 1, 1}, + {14, 14, 14, 512, 512, 3, 3, 3, 2, 2, 2}, + { 7, 7, 7, 512, 2048, 1, 1, 1, 1, 1, 1}, + { 7, 7, 7, 2048, 512, 1, 1, 1, 1, 1, 1}, + { 7, 7, 7, 512, 512, 3, 3, 3, 1, 1, 1}, + }; + + Result::print_header(std::cout, options) << std::endl; + + int idx = 1; + + for (auto const &layer : layers) { + for (auto N : batch_sizes) { + options.update({N, layer.d, layer.h, layer.w, layer.c}, + {layer.k, layer.t, layer.r, layer.s, layer.c}, + cutlass::make_Coord(layer.stride_d, layer.stride_h, layer.stride_w)); + + Result result = profile_convolution(options); + result.print(std::cout, idx, options) 
<< std::endl; + } + + ++idx; + } + } + else { + + // Execute one problem size + if (!options.valid()) { + std::cerr << "Invalid problem." << std::endl; + return -1; + } + + Result result = profile_convolution(options); + + Result::print_header(std::cout, options) << std::endl; + result.print(std::cout, 1, options) << std::endl; + } + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu b/examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu index fe756fbadd..87ed21c013 100644 --- a/examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu +++ b/examples/25_ampere_fprop_mainloop_fusion/ampere_fprop_mainloop_fusion.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -429,9 +429,13 @@ Result profile_convolution(Options const &options) { ElementInputB(-8), 0); - // Fill tensor C on host with zeros - cutlass::reference::host::TensorFill( - tensor_c.host_view()); + // Fill tensor C on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_c.host_view(), + 1, + ElementOutput(7), + ElementOutput(-8), + 0); // Fill tensor D on host with zeros cutlass::reference::host::TensorFill( @@ -575,7 +579,7 @@ Result profile_convolution(Options const &options) { std::stringstream ss; - ss << "25_ampere_fprop_mainloop_fusion_" + ss << "25_ampere_fprop_mainloop_fusion" << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c() << "_" << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c() @@ -677,8 +681,8 @@ int main(int argc, char const **args) { cudaDeviceProp props; CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); - if (!(props.major == 8 && props.minor == 0)) { - std::cerr << "This test must run on SM80 A100.\n"; + if (!(props.major >= 8)) { + std::cerr << "This test must run on SM80 or above.\n"; notSupported = true; } diff --git a/examples/26_ampere_wgrad_mainloop_fusion/CMakeLists.txt b/examples/26_ampere_wgrad_mainloop_fusion/CMakeLists.txt index f836f6e08c..e96050c370 100644 --- a/examples/26_ampere_wgrad_mainloop_fusion/CMakeLists.txt +++ b/examples/26_ampere_wgrad_mainloop_fusion/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without diff --git a/examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu b/examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu index 72d7284f6f..abb66b52d6 100644 --- a/examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu +++ b/examples/26_ampere_wgrad_mainloop_fusion/ampere_wgrad_mainloop_fusion.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -266,8 +266,8 @@ struct Options { /// Prints the usage statement. std::ostream & print_usage(std::ostream &out) const { - out << "26_ampere_fused_wgrad_batch_normalization example\n\n" - << " This example fuses scale+bias+relu from batch norm into Ampere's\n" + out << "26_ampere_wgrad_mainloop_fusion example\n\n" + << " This example fuses scale+bias+relu of the activation into Ampere's\n" << " Tensor Core operators on F16 data types to compute\n" << " backward convolution on tensors of layout NHWC.\n\n" << "Options:\n\n" @@ -289,8 +289,8 @@ struct Options { << " --tag= String to replicate across the first column in the results table\n"; out << "\n\nExamples:\n\n" - << "$ ./examples/26_ampere_fused_fprop_batch_normalization/26_ampere_fused_wgrad_batch_normalization --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n" - << "$ ./examples/26_ampere_fused_fprop_batch_normalization/26_ampere_fused_wgrad_batch_normalization --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n"; + << "$ ./examples/26_ampere_wgrad_mainloop_fusion/26_ampere_wgrad_mainloop_fusion --n=32 --h=224 --w=224 --c=128 --k=256 --r=1 --s=1\n\n" + << "$ ./examples/26_ampere_wgrad_mainloop_fusion/26_ampere_wgrad_mainloop_fusion --n=1 --h=224 --w=224 --c=32 --k=32 --r=3 --s=3 --ref-check\n\n"; return out; } @@ -427,9 +427,13 @@ Result profile_convolution(Options const &options) { ElementInputA(-4), 0); - // Fill tensor C on host with zeros - cutlass::reference::host::TensorFill( - tensor_c.host_view()); + // Fill tensor C on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_c.host_view(), + 1, + ElementOutput(7), + ElementOutput(-8), + 0); // Fill tensor D on host with zeros cutlass::reference::host::TensorFill( diff --git a/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu b/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu index 06559637fe..9e561cb6a2 100644 --- a/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu +++ b/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -36,7 +36,7 @@ implicitly to tf32 inside the GEMM kernel which means no change is needed to acc fp32 data by using NVIDIA Ampere architecture. We can use the tf32 mode of tensor core to emulate a fast accurate SGEMM kernel which is accelerated -using Ampere Tensor Cores (see include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h). +using Ampere Tensor Cores (see include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h). The trick is very simple a x b = (a_big + a_small) x (b_big + b_small) = a_big x b_big + a_big x b_small + a_small x b_big @@ -45,11 +45,11 @@ The trick is very simple a_small x b_small is discarded because they are too small. -This example demonstrates usage of this kernel, along with accuracy measurements w.r.t. actual FP32 +This example demonstrates usage of this kernel, along with accuracy measurements w.r.t. actual FP32 results (SGEMM using SIMT) and against FP64 results (DGEMM) -To enable this feature, the only change needs to make is to change the default OpMultiplyAdd to -OpMultiplyAddFastF32. +To enable this feature, the only change needs to make is to change the default OpMultiplyAdd to +OpMultiplyAddFastF32. Now, we have several different flavors of sgemm now in the profiler for Ampere. Here are the difference @@ -97,14 +97,14 @@ struct Result { double l2_norm_fp32_vs_fp64; // ctor - Result( + Result( int m, int n, int k, double runtime_ms, double gflops, double l2_norm_3xtf32_vs_fp64, double l2_norm_1xtf32_vs_fp64, - double l2_norm_fp32_vs_fp64) : + double l2_norm_fp32_vs_fp64) : m(m), n(n), k(k), - runtime_ms(runtime_ms), gflops(gflops), + runtime_ms(runtime_ms), gflops(gflops), l2_norm_3xtf32_vs_fp64(l2_norm_3xtf32_vs_fp64), l2_norm_1xtf32_vs_fp64(l2_norm_1xtf32_vs_fp64), l2_norm_fp32_vs_fp64(l2_norm_fp32_vs_fp64) {} @@ -147,7 +147,7 @@ struct Options { int iterations; int seed; bool benchmark; - + Options(): help(false), problem_size({3456, 4096, 4096}), @@ -190,7 +190,7 @@ struct Options { cmd.get_cmd_line_argument("alpha", alpha); cmd.get_cmd_line_argument("beta", beta); - + cmd.get_cmd_line_argument("iterations", iterations); cmd.get_cmd_line_argument("seed", seed); cmd.get_cmd_line_argument("rand_mode", rand_mode); @@ -227,9 +227,9 @@ struct Options { /// Compute performance in GFLOP/s double gflops(double runtime_s) const { - // Number of real-valued multiply-adds + // Number of real-valued multiply-adds int64_t fmas = problem_size.product(); - + // Two flops per multiply-add return 2.0 * double(fmas) / double(1.0e9) / runtime_s; } @@ -258,7 +258,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 32, 16>; // <- warp tile M = using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8 // This code section describes how threadblocks are scheduled on GPU -using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? 
+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // This code section describes the epilogue part of the kernel using EpilogueOp = cutlass::epilogue::thread::LinearCombination< @@ -272,10 +272,10 @@ using EpilogueOp = cutlass::epilogue::thread::LinearCombination< // Number of pipelines you want to use constexpr int NumStages = 3; -// Alignment +// Alignment constexpr int Alignment = 4; -// +// // Gemm Operators (Gemm_3xTF32, Gemm_1xTF32, GEMM_F32, GEMM_F64) // @@ -296,7 +296,7 @@ using Gemm_3xTF32 = cutlass::gemm::device::Gemm< EpilogueOp, SwizzleThreadBlock, NumStages, - Alignment, + Alignment, Alignment, false, cutlass::arch::OpMultiplyAddFastF32>; @@ -318,7 +318,7 @@ using Gemm_1xTF32 = cutlass::gemm::device::Gemm< EpilogueOp, SwizzleThreadBlock, NumStages, - Alignment, + Alignment, Alignment, false, cutlass::arch::OpMultiplyAdd>; @@ -356,7 +356,7 @@ bool run(Options &options) { cutlass::HostTensor tensor_a_F32(problem_size.mk()); // <- Create matrix A with dimensions M x K cutlass::HostTensor tensor_b_F32(problem_size.kn()); // <- Create matrix B with dimensions K x N cutlass::HostTensor tensor_c_F32(problem_size.mn()); // <- Create matrix C with dimensions M x N - cutlass::HostTensor tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N + cutlass::HostTensor tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N if (options.rand_mode == "uniform") { const float min = -1; @@ -397,7 +397,7 @@ bool run(Options &options) { } cutlass::reference::host::TensorFill( tensor_d_F32.host_view()); // <- fill matrix D on host with zeros - + // Copy data from host to GPU tensor_a_F32.sync_device(); tensor_b_F32.sync_device(); @@ -411,7 +411,7 @@ bool run(Options &options) { cutlass::HostTensor tensor_a_F64(problem_size.mk()); // <- Create matrix A with dimensions M x K cutlass::HostTensor tensor_b_F64(problem_size.kn()); // <- Create matrix B with dimensions K x N cutlass::HostTensor tensor_c_F64(problem_size.mn()); // <- Create matrix C with dimensions M x N - + // Gemm output (D) for GEMM_F64 cutlass::HostTensor tensor_d_F64(problem_size.mn()); // <- Create matrix D with dimensions M x N // Gemm output (D) for GEMM_3xTF32 @@ -426,7 +426,7 @@ bool run(Options &options) { cutlass::reference::host::TensorCopy(tensor_d_F64.host_view(), tensor_d_F32.host_view()); cutlass::reference::host::TensorCopy(tensor_d_3xTF32.host_view(), tensor_d_F32.host_view()); cutlass::reference::host::TensorCopy(tensor_d_1xTF32.host_view(), tensor_d_F32.host_view()); - + // Copy data from host to GPU tensor_a_F64.sync_device(); tensor_b_F64.sync_device(); @@ -464,7 +464,7 @@ bool run(Options &options) { // Instantiate CUTLASS kernel depending on templates Gemm_3xTF32 gemm_op_3xTF32; - // Check the problem size is supported or not + // Check the problem size is supported or not cutlass::Status status_3xtf32 = gemm_op_3xTF32.can_implement(arguments_3xtf32); CUTLASS_CHECK(status_3xtf32); @@ -568,7 +568,7 @@ bool run(Options &options) { // Instantiate CUTLASS kernel depending on templates Gemm_1xTF32 gemm_op_1xtf32; - // Check the problem size is supported or not + // Check the problem size is supported or not cutlass::Status status_1xtf32 = gemm_op_1xtf32.can_implement(arguments_1xtf32); CUTLASS_CHECK(status_1xtf32); @@ -627,7 +627,7 @@ bool run(Options &options) { tensor_d_F32.sync_host(); //////////////////////////////////////////////////////////////////////////////// - /////// Compute l2 norms + /////// Compute l2 norms 
//////////////////////////////////////////////////////////////////////////////// // l2 norm 3xTF32 vs F64 @@ -664,7 +664,7 @@ bool run(Options &options) { std::cout << "GFLOPs: " << result.gflops << std::endl; std::cout << "Normalized L2 norm of" << std::endl; std::cout.precision(8); - std::cout << std::scientific + std::cout << std::scientific << " - 3xTF32 error with FP64 reference : " << result.l2_norm_3xtf32_vs_fp64 << std::endl << " - 1xTF32 error with FP64 reference : " << result.l2_norm_1xtf32_vs_fp64 << std::endl << " - FP32 error with FP64 reference : " << result.l2_norm_fp32_vs_fp64 << std::endl; @@ -673,11 +673,11 @@ bool run(Options &options) { } int main(int argc, const char **argv) { - + bool notSupported = false; // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available - // in CUDA 11.0. + // in CUDA 11.0. // // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. if (!(__CUDACC_VER_MAJOR__ >= 11)) { @@ -690,7 +690,7 @@ int main(int argc, const char **argv) { cudaError_t error = cudaGetDeviceProperties(&props, 0); if (error != cudaSuccess) { std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; - return false; + return -1; } if (!((props.major * 10 + props.minor) >= 80)) { @@ -716,17 +716,17 @@ int main(int argc, const char **argv) { if (options.benchmark) { for (int k = 4; k <= 65536; k *= 2) { - + options.problem_size[2] = k; - + printf("Gemm problem size: %d x %d x %d\n", \ options.problem_size.m(), options.problem_size.n(), options.problem_size.k()); - + if (!options.valid()) { std::cerr << "Invalid problem." << std::endl; return -1; } - + result &= run(options); } } else { diff --git a/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/CMakeLists.txt b/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/CMakeLists.txt index c551b1256f..5b38de6e91 100644 --- a/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/CMakeLists.txt +++ b/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without diff --git a/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/CMakeLists.txt b/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/CMakeLists.txt index 04ac4bd183..50a7c9e619 100644 --- a/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/CMakeLists.txt +++ b/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without diff --git a/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu b/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu index a197e2efce..d2a3b4c693 100644 --- a/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu +++ b/examples/28_ampere_3xtf32_fast_accurate_tensorop_fprop/ampere_3xtf32_fast_accurate_tensorop_fprop.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -740,7 +740,7 @@ int main(int argc, char const **args) { cudaDeviceProp props; CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); - if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) { + if (!(props.major >= 8)) { std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." << std::endl; notSupported = true; diff --git a/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm.cu b/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm.cu similarity index 97% rename from examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm.cu rename to examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm.cu index 6f89d64fa7..0a995bf929 100644 --- a/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm.cu +++ b/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -34,7 +34,7 @@ difference is that this example uses 3xtf32 on complex gemm. To enable this feature, the only change needs to make is to change OpMultiplyAddComplex - to OpMultiplyAddComplexFastF32. + to OpMultiplyAddComplexFastF32. 
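The "big + small" trick that the comments in examples 27 and 29 above describe can be illustrated on the host. The following is a minimal sketch, not part of either example: the to_tf32 helper is hypothetical and truncates the low mantissa bits for illustration, whereas the tensor cores round to nearest.

#include <cstdint>
#include <cstring>

// Round a float to TF32 precision (10 explicit mantissa bits) by truncation.
inline float to_tf32(float x) {
  std::uint32_t bits;
  std::memcpy(&bits, &x, sizeof(bits));
  bits &= 0xffffe000u;                 // drop the low 13 of 23 mantissa bits
  std::memcpy(&x, &bits, sizeof(x));
  return x;
}

// Approximate an FP32 product with three TF32 products.
inline float product_3xtf32(float a, float b) {
  float a_big = to_tf32(a), a_small = to_tf32(a - a_big);
  float b_big = to_tf32(b), b_small = to_tf32(b - b_big);
  // a_small x b_small is dropped, exactly as the comments above note.
  return a_big * b_big + a_big * b_small + a_small * b_big;
}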
*/ #include @@ -74,14 +74,14 @@ struct Result { double l2_norm_fp32_vs_fp64; // ctor - Result( + Result( int m, int n, int k, double runtime_ms, double gflops, double l2_norm_3xtf32_vs_fp64, double l2_norm_1xtf32_vs_fp64, - double l2_norm_fp32_vs_fp64) : + double l2_norm_fp32_vs_fp64) : m(m), n(n), k(k), - runtime_ms(runtime_ms), gflops(gflops), + runtime_ms(runtime_ms), gflops(gflops), l2_norm_3xtf32_vs_fp64(l2_norm_3xtf32_vs_fp64), l2_norm_1xtf32_vs_fp64(l2_norm_1xtf32_vs_fp64), l2_norm_fp32_vs_fp64(l2_norm_fp32_vs_fp64) {} @@ -124,7 +124,7 @@ struct Options { int iterations; int seed; bool benchmark; - + Options(): help(false), problem_size({3456, 4096, 4096}), @@ -153,7 +153,7 @@ struct Options { cmd.get_cmd_line_argument("alpha", alpha); cmd.get_cmd_line_argument("beta", beta); - + cmd.get_cmd_line_argument("iterations", iterations); cmd.get_cmd_line_argument("seed", seed); cmd.get_cmd_line_argument("rand_mode", rand_mode); @@ -181,7 +181,7 @@ struct Options { << " --benchmark If set (true), performance benchmarking on several layers and batch-size.\n\n"; out << "\n\nExamples:\n\n" - << "$ ./examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_ampere_3xtf32_fast_accurate_complex_gemm --m=1024 --n=512 \\\n" + << "$ ./examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/29_3xtf32_complex_gemm --m=1024 --n=512 \\\n" << " --alpha=2 --beta=0.707 \n\n"; return out; @@ -190,9 +190,9 @@ struct Options { /// Compute performance in GFLOP/s double gflops(double runtime_s) const { - // Number of real-valued multiply-adds + // Number of real-valued multiply-adds int64_t fmas = problem_size.product(); - + // Two flops per multiply-add return 2.0 * double(fmas) / double(1.0e9) / runtime_s; } @@ -221,7 +221,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<32, 32, 16>; // <- warp tile M = using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8 // This code section describes how threadblocks are scheduled on GPU -using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? 
+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // This code section describes the epilogue part of the kernel using EpilogueOp = cutlass::epilogue::thread::LinearCombination< @@ -239,7 +239,7 @@ constexpr int NumStages = 3; constexpr cutlass::ComplexTransform TransformA = cutlass::ComplexTransform::kNone; constexpr cutlass::ComplexTransform TransformB = cutlass::ComplexTransform::kNone; -// +// // Gemm Operators (Gemm_3xTF32, Gemm_1xTF32, GEMM_F32, GEMM_F64) // @@ -260,7 +260,7 @@ using Gemm_3xTF32 = cutlass::gemm::device::GemmComplex< EpilogueOp, SwizzleThreadBlock, NumStages, - TransformA, + TransformA, TransformB, cutlass::arch::OpMultiplyAddComplexFastF32>; @@ -281,7 +281,7 @@ using Gemm_1xTF32 = cutlass::gemm::device::GemmComplex< EpilogueOp, SwizzleThreadBlock, NumStages, - TransformA, + TransformA, TransformB, cutlass::arch::OpMultiplyAddComplex>; @@ -296,7 +296,7 @@ bool run(Options &options) { cutlass::HostTensor, LayoutInputA> tensor_a_F32(problem_size.mk()); // <- Create matrix A with dimensions M x K cutlass::HostTensor, LayoutInputB> tensor_b_F32(problem_size.kn()); // <- Create matrix B with dimensions K x N cutlass::HostTensor, LayoutOutput> tensor_c_F32(problem_size.mn()); // <- Create matrix C with dimensions M x N - cutlass::HostTensor, LayoutOutput> tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N + cutlass::HostTensor, LayoutOutput> tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N if (options.rand_mode == "uniform") { const float min = -1; @@ -337,7 +337,7 @@ bool run(Options &options) { } cutlass::reference::host::TensorFill( tensor_d_F32.host_view()); // <- fill matrix D on host with zeros - + // Copy data from host to GPU tensor_a_F32.sync_device(); tensor_b_F32.sync_device(); @@ -351,7 +351,7 @@ bool run(Options &options) { cutlass::HostTensor, LayoutInputA> tensor_a_F64(problem_size.mk()); // <- Create matrix A with dimensions M x K cutlass::HostTensor, LayoutInputB> tensor_b_F64(problem_size.kn()); // <- Create matrix B with dimensions K x N cutlass::HostTensor, LayoutOutput> tensor_c_F64(problem_size.mn()); // <- Create matrix C with dimensions M x N - + // Gemm output (D) for GEMM_F64 cutlass::HostTensor, LayoutOutput> tensor_d_F64(problem_size.mn()); // <- Create matrix D with dimensions M x N // Gemm output (D) for GEMM_3xTF32 @@ -366,7 +366,7 @@ bool run(Options &options) { cutlass::reference::host::TensorCopy(tensor_d_F64.host_view(), tensor_d_F32.host_view()); cutlass::reference::host::TensorCopy(tensor_d_3xTF32.host_view(), tensor_d_F32.host_view()); cutlass::reference::host::TensorCopy(tensor_d_1xTF32.host_view(), tensor_d_F32.host_view()); - + // Copy data from host to GPU tensor_a_F64.sync_device(); tensor_b_F64.sync_device(); @@ -404,7 +404,7 @@ bool run(Options &options) { // Instantiate CUTLASS kernel depending on templates Gemm_3xTF32 gemm_op; - // Check the problem size is supported or not + // Check the problem size is supported or not cutlass::Status status_3xtf32 = gemm_op.can_implement(arguments_3xtf32); CUTLASS_CHECK(status_3xtf32); @@ -508,7 +508,7 @@ bool run(Options &options) { // Instantiate CUTLASS kernel depending on templates Gemm_1xTF32 gemm_op_1xtf32; - // Check the problem size is supported or not + // Check the problem size is supported or not cutlass::Status status_1xtf32 = gemm_op_1xtf32.can_implement(arguments_1xtf32); CUTLASS_CHECK(status_1xtf32); @@ -569,7 +569,7 @@ bool run(Options &options) { tensor_d_F32.sync_host(); 
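The "Normalized L2 norm" values these examples print are relative error metrics against the FP64 reference. A minimal host-side sketch of the assumed computation (intended to mirror what cutlass::reference::host::TensorRelativeErrorMetric reports: the norm of the difference divided by the norm of the reference):

#include <cmath>
#include <cstddef>
#include <vector>

// Sketch only: relative L2 error of a result x against a higher-precision reference.
double normalized_l2_error(std::vector<double> const &x, std::vector<double> const &ref) {
  double num = 0.0, den = 0.0;
  for (std::size_t i = 0; i < x.size() && i < ref.size(); ++i) {
    double d = x[i] - ref[i];
    num += d * d;
    den += ref[i] * ref[i];
  }
  return std::sqrt(num) / std::sqrt(den);
}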
//////////////////////////////////////////////////////////////////////////////// - /////// Compute l2 norms + /////// Compute l2 norms //////////////////////////////////////////////////////////////////////////////// // l2 norm 3xTF32 vs F64 @@ -606,7 +606,7 @@ bool run(Options &options) { std::cout << "GFLOPs: " << result.gflops << std::endl; std::cout << "Normalized L2 norm of" << std::endl; std::cout.precision(8); - std::cout << std::scientific + std::cout << std::scientific << " - 3xTF32 error with FP64 reference : " << result.l2_norm_3xtf32_vs_fp64 << std::endl << " - 1xTF32 error with FP64 reference : " << result.l2_norm_1xtf32_vs_fp64 << std::endl << " - FP32 error with FP64 reference : " << result.l2_norm_fp32_vs_fp64 << std::endl; @@ -615,11 +615,11 @@ bool run(Options &options) { } int main(int argc, const char **argv) { - + bool notSupported = false; // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available - // in CUDA 11.0. + // in CUDA 11.0. // // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. if (!(__CUDACC_VER_MAJOR__ >= 11)) { @@ -632,7 +632,7 @@ int main(int argc, const char **argv) { cudaError_t error = cudaGetDeviceProperties(&props, 0); if (error != cudaSuccess) { std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; - return false; + return -1; } if (!((props.major * 10 + props.minor) >= 80)) { @@ -658,17 +658,17 @@ int main(int argc, const char **argv) { if (options.benchmark) { for (int k = 4; k <= 65536; k *= 2) { - + options.problem_size[2] = k; - + printf("Gemm problem size: %d x %d x %d\n", \ options.problem_size.m(), options.problem_size.n(), options.problem_size.k()); - + if (!options.valid()) { std::cerr << "Invalid problem." << std::endl; return -1; } - + result &= run(options); } } else { diff --git a/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/CMakeLists.txt b/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/CMakeLists.txt index c7e896ba90..e406a7eda7 100644 --- a/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/CMakeLists.txt +++ b/examples/29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -27,9 +27,9 @@ # OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE # OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. - +# Both filenames are shorter to avoid MAX_PATH issues on Windows. cutlass_example_add_executable( - 29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm - 29_ampere_3xtf32_fast_accurate_tensorop_complex_gemm.cu + 29_3xtf32_complex_gemm + 29_3xtf32_complex_gemm.cu ) diff --git a/examples/30_wgrad_split_k/30_wgrad_split_k.cu b/examples/30_wgrad_split_k/30_wgrad_split_k.cu index 5016adf292..822a7a55f8 100644 --- a/examples/30_wgrad_split_k/30_wgrad_split_k.cu +++ b/examples/30_wgrad_split_k/30_wgrad_split_k.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -531,17 +531,17 @@ Result profile_convolution(Options const &options) { // Reduction input { reinterpret_cast (workspace.get()), - ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx]) + ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::UnderlyingKernel::kTensorCStrideIdx]) }, // Destination { tensor_d.device_data(), - ReductionStrideIndex(tensor_d.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx]) + ReductionStrideIndex(tensor_d.stride()[ImplicitGemm::UnderlyingKernel::kTensorCStrideIdx]) }, // Source { tensor_c.device_data(), - ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::ImplicitGemmKernel::kTensorCStrideIdx]) + ReductionStrideIndex(tensor_c.stride()[ImplicitGemm::UnderlyingKernel::kTensorCStrideIdx]) }, {options.alpha, options.beta} ); @@ -602,7 +602,7 @@ Result profile_convolution(Options const &options) { std::stringstream ss; - ss << "26_ampere_fused_wgrad_batch_normalization_" + ss << "30_wgrad_split_k_" << options.input_size.n() << "x" << options.input_size.h() << "x" << options.input_size.w() << "x" << options.input_size.c() << "_" << options.filter_size.n() << "x" << options.filter_size.h() << "x" << options.filter_size.w() << "x" << options.filter_size.c() @@ -703,7 +703,7 @@ int main(int argc, char const **args) { cudaDeviceProp props; CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); - if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) { + if (!(props.major >= 8)) { std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." << std::endl; notSupported = true; diff --git a/examples/30_wgrad_split_k/CMakeLists.txt b/examples/30_wgrad_split_k/CMakeLists.txt index 3fc5a8954f..98eda79126 100644 --- a/examples/30_wgrad_split_k/CMakeLists.txt +++ b/examples/30_wgrad_split_k/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without diff --git a/examples/31_basic_syrk/CMakeLists.txt b/examples/31_basic_syrk/CMakeLists.txt index e6c077502e..8d5571d237 100644 --- a/examples/31_basic_syrk/CMakeLists.txt +++ b/examples/31_basic_syrk/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without diff --git a/examples/31_basic_syrk/basic_syrk.cu b/examples/31_basic_syrk/basic_syrk.cu index 79b3ab6019..9f9cd93a35 100644 --- a/examples/31_basic_syrk/basic_syrk.cu +++ b/examples/31_basic_syrk/basic_syrk.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -34,7 +34,7 @@ matrix multiply kernel to verify its correctness. The CUTLASS Syrk template is instantiated in the function CutlassSsyrkNN. 
This is kernel computes - the symmetric rank-k update (SYRK) using double-precision doubleing-point arithmetic and assumes + the symmetric rank-k update (SYRK) using double-precision floating-point arithmetic and assumes all matrices have column-major layout. The threadblock tile size is chosen as 16x32x16 which offers good performance for large matrices. @@ -113,10 +113,10 @@ cudaError_t CutlassSsyrkNN( >, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, 5, // Stages - 1, // AligmentA + 1, // AlignmentA false, // SplitKSerail - cutlass::arch::OpMultiplyAdd, - cutlass::ComplexTransform::kNone, + cutlass::arch::OpMultiplyAdd, + cutlass::ComplexTransform::kNone, cutlass::BlasMode::kSymmetric >; @@ -149,7 +149,7 @@ cudaError_t CutlassSsyrkNN( // // Launch the CUTLASS SYRK kernel. // - + cutlass::Status status = syrk_operator(args); // diff --git a/examples/32_basic_trmm/CMakeLists.txt b/examples/32_basic_trmm/CMakeLists.txt index 0e1afff190..459dbe8f94 100644 --- a/examples/32_basic_trmm/CMakeLists.txt +++ b/examples/32_basic_trmm/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without diff --git a/examples/32_basic_trmm/basic_trmm.cu b/examples/32_basic_trmm/basic_trmm.cu index 988b7a6298..d2eda76a0c 100644 --- a/examples/32_basic_trmm/basic_trmm.cu +++ b/examples/32_basic_trmm/basic_trmm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -34,7 +34,7 @@ matrix multiply kernel to verify its correctness. The CUTLASS Trmm template is instantiated in the function CutlassStrmmNN. This is kernel computes - the triangular matrix product (TRMM) using double-precision doubleing-point arithmetic and assumes + the triangular matrix product (TRMM) using double-precision floating-point arithmetic and assumes all matrices have column-major layout. The threadblock tile size is chosen as 64x64x16 which offers good performance for large matrices. diff --git a/examples/33_ampere_3xtf32_tensorop_symm/CMakeLists.txt b/examples/33_ampere_3xtf32_tensorop_symm/CMakeLists.txt index 858378da58..115040396e 100644 --- a/examples/33_ampere_3xtf32_tensorop_symm/CMakeLists.txt +++ b/examples/33_ampere_3xtf32_tensorop_symm/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
# SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without diff --git a/examples/33_ampere_3xtf32_tensorop_symm/ampere_3xtf32_tensorop_symm.cu b/examples/33_ampere_3xtf32_tensorop_symm/ampere_3xtf32_tensorop_symm.cu index 2e875a91c0..22cb3286eb 100644 --- a/examples/33_ampere_3xtf32_tensorop_symm/ampere_3xtf32_tensorop_symm.cu +++ b/examples/33_ampere_3xtf32_tensorop_symm/ampere_3xtf32_tensorop_symm.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -36,7 +36,7 @@ implicitly to tf32 inside the SYMM kernel which means no change is needed to acc F32 data by using NVIDIA Ampere architecture. We can use the tf32 mode of tensor core to emulate a fast accurate SYMM kernel which is accelerated -using Ampere Tensor Cores (see include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h). +using Ampere Tensor Cores (see include/cutlass/gemm/warp/mma_tensor_op_fast_f32.h). The trick is very simple a x b = (a_big + a_small) x (b_big + b_small) = a_big x b_big + a_big x b_small + a_small x b_big @@ -45,11 +45,11 @@ The trick is very simple a_small x b_small is discarded because they are too small. -This example demonstrates usage of this kernel, along with accuracy measurements w.r.t. actual F32 +This example demonstrates usage of this kernel, along with accuracy measurements w.r.t. actual F32 results (SSYMM from cuBLAS) and against F64 results (DSYMM from CUTLASS) -To enable this feature, the only change needs to make is to change the default OpMultiplyAdd to -OpMultiplyAddFastF32. +To enable this feature, the only change needs to make is to change the default OpMultiplyAdd to +OpMultiplyAddFastF32. Now, we have two different flavors of SSYMM in the profiler for Ampere: @@ -95,7 +95,7 @@ struct Options { float beta; std::string rand_mode; int seed; - + Options(): help(false), problem_size({4096, 4096, 4096}), @@ -137,7 +137,7 @@ struct Options { cmd.get_cmd_line_argument("alpha", alpha); cmd.get_cmd_line_argument("beta", beta); - + cmd.get_cmd_line_argument("seed", seed); cmd.get_cmd_line_argument("rand_mode", rand_mode); @@ -193,7 +193,7 @@ using ShapeMMAWarp = cutlass::gemm::GemmShape<64, 32, 16>; // <- warp tile M = using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 8>; // <- MMA Op tile M = 16, N = 8, K = 8 // This code section describes how threadblocks are scheduled on GPU -using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? 
+using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // This code section describes the epilogue part of the kernel using EpilogueOp = cutlass::epilogue::thread::LinearCombination< @@ -207,10 +207,10 @@ using EpilogueOp = cutlass::epilogue::thread::LinearCombination< // Number of pipelines you want to use constexpr int NumStages = 3; -// Alignment +// Alignment constexpr int Alignment = 4; -// +// // CUTLASS Symm Operators (SSYM: Symm_3xTF32, Symm_1xTF32, DSYMM: Symm_F64) // @@ -233,7 +233,7 @@ using Symm_3xTF32 = cutlass::gemm::device::Symm< EpilogueOp, SwizzleThreadBlock, NumStages, - 1, // Symmetric matrix is always align 1 + 1, // Symmetric matrix is always align 1 Alignment, false, cutlass::arch::OpMultiplyAddFastF32>; @@ -257,7 +257,7 @@ using Symm_1xTF32 = cutlass::gemm::device::Symm< EpilogueOp, SwizzleThreadBlock, NumStages, - 1, // Symmetric matrix is always align 1 + 1, // Symmetric matrix is always align 1 Alignment, false, cutlass::arch::OpMultiplyAdd>; @@ -298,7 +298,7 @@ bool run(Options &options) { cutlass::HostTensor tensor_a_F32(problem_size.mk()); // <- Create matrix A with dimensions M x K cutlass::HostTensor tensor_b_F32(problem_size.kn()); // <- Create matrix B with dimensions K x N cutlass::HostTensor tensor_c_F32(problem_size.mn()); // <- Create matrix C with dimensions M x N - cutlass::HostTensor tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N + cutlass::HostTensor tensor_d_F32(problem_size.mn()); // <- Create matrix D with dimensions M x N if (options.rand_mode == "uniform") { const float min = -1; @@ -339,7 +339,7 @@ bool run(Options &options) { } cutlass::reference::host::TensorFill( tensor_d_F32.host_view()); // <- fill matrix D on host with zeros - + // Copy data from host to GPU tensor_a_F32.sync_device(); tensor_b_F32.sync_device(); @@ -353,7 +353,7 @@ bool run(Options &options) { cutlass::HostTensor tensor_a_F64(problem_size.mk()); // <- Create matrix A with dimensions M x K cutlass::HostTensor tensor_b_F64(problem_size.kn()); // <- Create matrix B with dimensions K x N cutlass::HostTensor tensor_c_F64(problem_size.mn()); // <- Create matrix C with dimensions M x N - + // Symm output (D) for SYMM_3xTF32 cutlass::HostTensor tensor_d_3xTF32(problem_size.mn()); // <- Create matrix D with dimensions M x N // Symm output (D) for SYMM_1xTF32 @@ -375,7 +375,7 @@ bool run(Options &options) { #if CUTLASS_ENABLE_CUBLAS cutlass::reference::host::TensorCopy(tensor_d_cublasF32.host_view(), tensor_d_F32.host_view()); #endif - + // Copy data from host to GPU tensor_a_F64.sync_device(); tensor_b_F64.sync_device(); @@ -430,7 +430,7 @@ bool run(Options &options) { // Instantiate CUTLASS kernel depending on templates Symm_3xTF32 symm_op_3xtf32; - // Check the problem size is supported or not + // Check the problem size is supported or not cutlass::Status status_3xtf32 = symm_op_3xtf32.can_implement(arguments_3xtf32); CUTLASS_CHECK(status_3xtf32); @@ -477,7 +477,7 @@ bool run(Options &options) { // Instantiate CUTLASS kernel depending on templates Symm_1xTF32 symm_op_1xtf32; - // Check the problem size is supported or not + // Check the problem size is supported or not cutlass::Status status_1xtf32 = symm_op_1xtf32.can_implement(arguments_1xtf32); CUTLASS_CHECK(status_1xtf32); @@ -524,7 +524,7 @@ bool run(Options &options) { // Instantiate CUTLASS kernel depending on templates Symm_F64 symm_op_f64; - // Check the problem size is supported or not + // Check the problem size is supported or not cutlass::Status 
status_f64 = symm_op_f64.can_implement(arguments_f64); CUTLASS_CHECK(status_f64); @@ -568,7 +568,7 @@ bool run(Options &options) { static_cast(&beta), static_cast(tensor_d_cublasF32.device_data()), int(tensor_d_cublasF32.layout().stride(0)) - ); + ); cudaDeviceSynchronize(); @@ -576,7 +576,7 @@ bool run(Options &options) { #endif //////////////////////////////////////////////////////////////////////////////// - /// 7. Compute l2 norms + /// 7. Compute l2 norms //////////////////////////////////////////////////////////////////////////////// #if CUTLASS_ENABLE_CUBLAS @@ -605,20 +605,20 @@ bool run(Options &options) { double l2_norm_3xtf32_vs_cublasf32 = cutlass::reference::host::TensorRelativeErrorMetric( tensor_d_3xTF32.host_view(), tensor_d_cublasF32.host_view()); #endif - + // l2 norm 3xTF32 vs 1xTF32 double l2_norm_3xtf32_vs_1xtf32 = cutlass::reference::host::TensorRelativeErrorMetric( tensor_d_3xTF32.host_view(), tensor_d_1xTF32.host_view()); /////////////////////////////////////////////////////////////////////////////// - // Print kernel info and L2 norms + // Print kernel info and L2 norms std::cout << "Problem Size: (" << problem_size.m() << "," << problem_size.n() << "," << problem_size.k() << ") " << "Alpha: " << alpha << "," << " Beta: " << beta << std::endl; std::cout << std::fixed; std::cout << "Normalized L2 norm of" << std::endl; std::cout.precision(8); - std::cout << std::scientific + std::cout << std::scientific #if CUTLASS_ENABLE_CUBLAS << " - cuBLAS F32 error with F64 reference : " << l2_norm_cublasf32_vs_f64 << std::endl #endif @@ -633,11 +633,11 @@ bool run(Options &options) { } int main(int argc, const char **argv) { - + bool notSupported = false; // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available - // in CUDA 11.0. + // in CUDA 11.0. // // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. if (!(__CUDACC_VER_MAJOR__ >= 11)) { @@ -650,7 +650,7 @@ int main(int argc, const char **argv) { cudaError_t error = cudaGetDeviceProperties(&props, 0); if (error != cudaSuccess) { std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; - return false; + return -1; } if (!((props.major * 10 + props.minor) >= 80)) { diff --git a/examples/34_transposed_conv2d/34_transposed_conv2d.cu b/examples/34_transposed_conv2d/34_transposed_conv2d.cu index d9d878ad27..f3393c7ce5 100644 --- a/examples/34_transposed_conv2d/34_transposed_conv2d.cu +++ b/examples/34_transposed_conv2d/34_transposed_conv2d.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -251,7 +251,7 @@ struct Options { << " --tag= String to replicate across the first column in the results table\n"; out << "\n\nExamples:\n\n" - << "$ ./examples/31_transposed_conv2d/31_transposed_conv2d --n=8 --h=32 --w=32 --c=16 --k=32 --r=3 --s=3\n\n"; + << "$ ./examples/34_transposed_conv2d/34_transposed_conv2d --n=8 --h=32 --w=32 --c=16 --k=32 --r=3 --s=3\n\n"; return out; } @@ -603,7 +603,7 @@ int main(int argc, char const **args) { cudaDeviceProp props; CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); - if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) { + if (!(props.major >= 8)) { std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." << std::endl; notSupported = true; diff --git a/examples/34_transposed_conv2d/CMakeLists.txt b/examples/34_transposed_conv2d/CMakeLists.txt index 00a0dbce73..414b011ac3 100644 --- a/examples/34_transposed_conv2d/CMakeLists.txt +++ b/examples/34_transposed_conv2d/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without diff --git a/examples/35_gemm_softmax/CMakeLists.txt b/examples/35_gemm_softmax/CMakeLists.txt index 51611290bd..b7ecd99fcc 100644 --- a/examples/35_gemm_softmax/CMakeLists.txt +++ b/examples/35_gemm_softmax/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without diff --git a/examples/35_gemm_softmax/gemm_softmax.cu b/examples/35_gemm_softmax/gemm_softmax.cu index 0d18077ee4..731e37b4d9 100644 --- a/examples/35_gemm_softmax/gemm_softmax.cu +++ b/examples/35_gemm_softmax/gemm_softmax.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -42,19 +42,24 @@ #include "cutlass/arch/memory.h" #include "cutlass/arch/memory_sm75.h" #include "cutlass/gemm/device/gemm_complex.h" - +#include "cutlass/numeric_types.h" +#include "cutlass/numeric_size.h" #include "cutlass/util/command_line.h" #include "cutlass/util/host_tensor.h" #include "cutlass/util/reference/host/gemm_complex.h" +#include "cutlass/util/reference/device/gemm_complex.h" #include "cutlass/util/reference/host/tensor_reduce.h" #include "cutlass/util/reference/host/tensor_compare.h" #include "cutlass/util/reference/host/tensor_norm.h" #include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/device/tensor_fill.h" #include "cutlass/util/reference/host/tensor_fill.h" #include "cutlass/util/reference/host/error_metrics.h" #include "cutlass/util/tensor_view_io.h" +#include "cutlass/numeric_size.h" // cutlass::bits_to_bytes +#include "cutlass/layout/matrix.h" #include "cutlass/epilogue/thread/linear_combination.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -85,18 +90,18 @@ struct Options { float alpha; float beta; bool verification_enabled; - double tolerance; + float tolerance; Options(): help(false), problem_size({16, 24, 64}), - batch_count(1), // As a temporary limitation to the test bench, batch count must be 1. The kernels support arbitrary batching. + batch_count(16), iterations(20), seed(2022), alpha(1), - beta(), + beta(0), verification_enabled(true), - tolerance(0.01) + tolerance(1e-5f) { } bool valid() { @@ -116,6 +121,8 @@ struct Options { cmd.get_cmd_line_argument("n", problem_size.n()); cmd.get_cmd_line_argument("k", problem_size.k()); + cmd.get_cmd_line_argument("batch_count", batch_count); + cmd.get_cmd_line_argument("alpha", alpha); cmd.get_cmd_line_argument("beta", beta); @@ -135,6 +142,7 @@ struct Options { << " --m= GEMM M dimension\n" << " --n= GEMM N dimension\n" << " --k= GEMM K dimension\n" + << " --batch_count= Batch number\n" << " --alpha= Epilogue scalar alpha\n" << " --beta= Epilogue scalar beta\n\n" << " --seed= Random number seed (1*)\n\n" @@ -198,13 +206,28 @@ struct Testbed { using ElementA = cutlass::half_t; using ElementB = cutlass::half_t; using ElementC = cutlass::half_t; - using ElementD = cutlass::half_t; using ElementCompute = float; - using ElementSoftmax = cutlass::half_t; + using ElementD = ElementC; + using ElementSoftmax = ElementC; using LayoutA = cutlass::layout::RowMajor; using LayoutB = cutlass::layout::ColumnMajor; + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using ArchTag = cutlass::arch::Sm80; + + // ApplyShape impacts the final Softmax performance a lot. + // Set ApplyShape::kColumn to be the next multiple of 32 number that is after + // (gemm_N / alignment). + // Set ApplyShape::kRow to max(1, 128 / ApplyShape::kColumn). 
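  // Worked illustration of the rule above (numbers chosen for illustration only):
  // with gemm_N = 8192 and an alignment of 8 elements, gemm_N / alignment = 1024,
  // which is already a multiple of 32, so ApplyShape::kColumn = 1024 and
  // ApplyShape::kRow = max(1, 128 / 1024) = 1 -- i.e. the MatrixShape<1, 1024> below.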
+ using ApplyShape = cutlass::MatrixShape<1, 1024>; + + static int const kStages = 3; + /// Linear scaling operator using EpilogueFunctorOp = cutlass::epilogue::thread::LinearCombination< ElementC, @@ -218,12 +241,22 @@ struct Testbed { ElementB, LayoutB, ElementC, ElementCompute, - EpilogueFunctorOp + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueFunctorOp, + kStages, + ApplyShape >; using ElementNorm = typename GemmSoftmax::ElementNorm; using ElementSum = typename GemmSoftmax::ElementSum; using LayoutC = typename GemmSoftmax::LayoutC; + using LayoutN = typename GemmSoftmax::LayoutN; + using LayoutS = typename GemmSoftmax::LayoutS; + using MatrixCoord = typename LayoutC::TensorCoord; // // Data members @@ -231,20 +264,42 @@ struct Testbed { Options const &options; - cutlass::HostTensor tensor_A; - cutlass::HostTensor tensor_B; - cutlass::HostTensor tensor_C; - cutlass::HostTensor tensor_D; - cutlass::HostTensor tensor_N; - cutlass::HostTensor tensor_S; - cutlass::HostTensor tensor_Softmax; - cutlass::HostTensor reference_D; cutlass::HostTensor reference_N; - cutlass::HostTensor reference_Softmax; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_Ref; + cutlass::DeviceAllocation block_Softmax; + cutlass::DeviceAllocation block_Norm; + cutlass::DeviceAllocation block_Sum; int block_num = (options.problem_size.n() + GemmSoftmax::ThreadblockShape::kN - 1) / GemmSoftmax::ThreadblockShape::kN; + cutlass::gemm::GemmCoord problem = options.problem_size; + + int64_t lda = LayoutA::packed({problem.m(), problem.k()}).stride(0); + int64_t ldb = LayoutB::packed({problem.k(), problem.n()}).stride(0); + int64_t ldc = LayoutC::packed({problem.m(), problem.n()}).stride(0); + + // fixed rowmajor for norm and sum + int64_t ldn = problem.m(); + int64_t lds = ldn; + + int64_t total_elements_A_per_batch = problem.m() * problem.k(); + int64_t total_elements_B_per_batch = problem.k() * problem.n(); + int64_t total_elements_C_per_batch = problem.m() * problem.n(); + int64_t total_elements_D_per_batch = problem.m() * problem.n(); + int64_t total_elements_partial_norm_per_batch = block_num * problem.m(); + + int64_t total_elements_A = total_elements_A_per_batch * options.batch_count; + int64_t total_elements_B = total_elements_B_per_batch * options.batch_count; + int64_t total_elements_C = total_elements_C_per_batch * options.batch_count; + int64_t total_elements_D = total_elements_D_per_batch * options.batch_count; + int64_t total_elements_partial_norm = total_elements_partial_norm_per_batch * options.batch_count; + // // Methods // @@ -254,20 +309,7 @@ struct Testbed { ): options(options_) { - - tensor_A.reset({options.problem_size.m(), options.problem_size.k()}); - tensor_B.reset({options.problem_size.k(), options.problem_size.n()}); - - tensor_C.reset({options.problem_size.m(), options.problem_size.n()}); - tensor_D.reset({options.problem_size.m(), options.problem_size.n()}); - - tensor_N.reset({block_num, options.problem_size.m()}); - tensor_S.reset({block_num, options.problem_size.m()}); - tensor_Softmax.reset({options.problem_size.m(), options.problem_size.n()}); - - reference_D.reset({options.problem_size.m(), options.problem_size.n()}, false); reference_N.reset({options.problem_size.m(), 1}, false); - reference_Softmax.reset({options.problem_size.m(), options.problem_size.n()}, false); } /// Run @@ -300,11 +342,6 @@ struct Testbed { 
return disposition; } - // - // Compute the reference - // - compute_reference(); - // // Verify // @@ -334,43 +371,38 @@ struct Testbed { /// Random initialization void initialize() { - cutlass::reference::host::TensorFillRandomUniform( - tensor_A.host_view(), - options.seed, - ElementD(5), - ElementD(-5), - 0 - ); + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + block_Softmax.reset(total_elements_D); + block_Ref.reset(total_elements_D_per_batch); + block_Norm.reset(total_elements_partial_norm); + block_Sum.reset(total_elements_partial_norm); - cutlass::reference::host::TensorFillRandomUniform( - tensor_B.host_view(), - options.seed + 19, - ElementD(5), - ElementD(-5), - 0 - ); + cutlass::reference::device::BlockFillRandomUniform( + block_A.get(), total_elements_A, options.seed, ElementA(5), ElementA(-5), 0); - cutlass::reference::host::TensorFill( - reference_D.host_view(), - ElementD() - ); + cutlass::reference::device::BlockFillRandomUniform( + block_B.get(), total_elements_B, options.seed + 1, ElementB(5), ElementB(-5), 0); + + cutlass::reference::device::BlockFillRandomUniform( + block_C.get(), total_elements_C, options.seed + 2, ElementC(5), ElementC(-5), 0); + + cutlass::reference::device::BlockFillRandomUniform( + block_D.get(), total_elements_D, options.seed + 3, ElementD(5), ElementD(-5), 0); + + cutlass::reference::device::BlockFillRandomUniform( + block_Ref.get(), total_elements_D_per_batch, options.seed + 3, ElementD(5), ElementD(-5), 0); + + cutlass::reference::device::BlockFillRandomUniform( + block_Softmax.get(), total_elements_D, options.seed + 3, ElementSoftmax(5), ElementSoftmax(-5), 0); cutlass::reference::host::TensorFill( reference_N.host_view(), ElementNorm() ); - cutlass::reference::host::TensorFill( - reference_Softmax.host_view(), - ElementSoftmax() - ); - - tensor_A.sync_device(); - tensor_B.sync_device(); - tensor_D.sync_device(); - tensor_N.sync_device(); - tensor_S.sync_device(); - tensor_Softmax.sync_device(); } cutlass::Status execute_device_kernel() { @@ -384,17 +416,24 @@ struct Testbed { GemmSoftmax::Arguments args( options.problem_size, options.batch_count, - tensor_A.device_ref(), - tensor_B.device_ref(), - tensor_C.device_ref(), - tensor_D.device_ref(), + {block_A.get(), lda}, + {block_B.get(), ldb}, + {block_C.get(), ldc}, + {block_D.get(), ldc}, { ElementCompute(options.alpha), ElementCompute(options.beta) }, - tensor_N.device_ref(), - tensor_S.device_ref(), - tensor_Softmax.device_ref() + {block_Norm.get(), ldn}, + {block_Sum.get(), lds}, + {block_Softmax.get(), ldc}, + total_elements_A_per_batch, + total_elements_B_per_batch, + total_elements_C_per_batch, + total_elements_D_per_batch, + total_elements_partial_norm_per_batch, + total_elements_partial_norm_per_batch, + total_elements_D_per_batch ); // @@ -415,68 +454,21 @@ struct Testbed { return status; } - /// Reference calculation - void compute_reference() { - - // Compute GEMM - - cutlass::reference::host::GemmComplex( - options.problem_size, - options.alpha, - tensor_A.host_ref(), - cutlass::ComplexTransform::kNone, - tensor_B.host_ref(), - cutlass::ComplexTransform::kNone, - options.beta, - tensor_C.host_ref(), - reference_D.host_ref(), - double() - ); - - // Compute the norm - for (int m = 0; m < options.problem_size.m(); ++m) { - reference_N.at({m, 0}) = reference_D.at({m, 0}); - for (int n = 1; n < options.problem_size.n(); ++n) { - reference_N.at({m, 0}) = std::max(reference_N.at({m, 0}), 
ElementNorm(reference_D.at({m, n}))); - } - } - - // Compute softmax - for (int m = 0; m < options.problem_size.m(); ++m) { - - float sum = float(); - - for (int n = 0; n < options.problem_size.n(); ++n) { - sum += std::exp( float(reference_D.at({m, n})) - float(reference_N.at({m, 0})) ); - } - - float inv_sum = float(1.0f / sum); - - for (int n = 0; n < options.problem_size.n(); ++n) { - - reference_Softmax.at({m, n}) = ElementSoftmax( - std::exp( float(reference_D.at({m, n})) - float(reference_N.at({m, 0})) ) * inv_sum - ); - } - } - } - - /// Emits all tensor values - void emit_results() { - std::cout << "D = \n" << tensor_D.host_view() << "\n\n"; - std::cout << "N = \n" << tensor_N.host_view() << "\n\n"; - std::cout << "Softmax = \n" << tensor_Softmax.host_view() << "\n\n"; - std::cout << "Reference N = \n" << reference_N.host_view() << "\n\n"; - std::cout << "Reference D = \n" << reference_D.host_view() << "\n\n"; - std::cout << "Reference Softmax = \n" << reference_Softmax.host_view() << "\n\n"; - } - - bool verify_tensor_N(cutlass::HostTensor tensor_N, \ - cutlass::HostTensor reference_N) { - - for (int m = 0; m < options.problem_size.m(); ++m) { - float diff = (float)(tensor_N.at({0, m}) - reference_N.at({m, 0})); - if (fabs(diff) > options.tolerance) { + template + bool verify_tensor(std::vector vector_Input, \ + std::vector vector_Input_Ref) { + + auto size = int64_t((vector_Input.size() < vector_Input_Ref.size()) ? vector_Input.size() : vector_Input_Ref.size()); + float abs_tol = options.tolerance; + float rel_tol = options.tolerance; + + for (int64_t i = 0; i < size; ++i) { + float diff = (float)(vector_Input.at(i) - vector_Input_Ref.at(i)); + float abs_diff = fabs(diff); + float abs_ref = fabs((float)vector_Input_Ref.at(i)); + float relative_diff = abs_ref > abs_tol ? abs_diff / abs_ref : 0; + if ( (isnan(abs_diff) || isinf(abs_diff)) || (abs_diff > rel_tol && relative_diff > rel_tol)) { + printf("diff = %f, {%f, %f}.\n", abs_diff, (float)(vector_Input.at(i)), (float)(vector_Input_Ref.at(i))); return false; } @@ -488,80 +480,112 @@ struct Testbed { /// Verifies the reference matches bool verify() { - tensor_D.sync_host(); - tensor_N.sync_host(); - tensor_Softmax.sync_host(); - - double const kThreshold = options.tolerance; - - // Verification checks - set any of these to 'true' to override the verification checks. 
- bool verified_D = false; - bool verified_N = false; - bool verified_Softmax = false; + LayoutA layout_A(lda); + LayoutB layout_B(ldb); + LayoutC layout_C(ldc); + LayoutN Layout_N(ldn); + LayoutS Layout_S(lds); + + MatrixCoord extent_A{problem.m(), problem.k()}; + MatrixCoord extent_B{problem.k(), problem.n()}; + MatrixCoord extent_C{problem.m(), problem.n()}; + + for (int batch_idx = 0; batch_idx < options.batch_count; batch_idx++) { + + cutlass::TensorView view_A(block_A.get() + total_elements_A_per_batch * batch_idx, layout_A, extent_A); + cutlass::TensorView view_B(block_B.get() + total_elements_B_per_batch * batch_idx, layout_B, extent_B); + cutlass::TensorView view_C(block_C.get() + total_elements_C_per_batch * batch_idx, layout_C, extent_C); + cutlass::TensorView view_Ref_device(block_Ref.get(), layout_C, extent_C); + + cutlass::reference::device::GemmComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, ElementCompute + >( + problem, + options.alpha, + view_A, + cutlass::ComplexTransform::kNone, + view_B, + cutlass::ComplexTransform::kNone, + options.beta, + view_C, + view_Ref_device, + ElementCompute(0) + ); - // Verify softmax output - if (!verified_D) { + // Copy reference results to host memory for verification + std::vector matrix_D_Ref(layout_C.capacity(extent_C)); + cutlass::device_memory::copy_to_host(matrix_D_Ref.data(), block_Ref.get(), matrix_D_Ref.size()); + cutlass::TensorView view_Ref(matrix_D_Ref.data(), layout_C, extent_C); - double norm_diff = cutlass::reference::host::TensorNormDiff( - tensor_D.host_view(), - reference_D.host_view()); + std::vector matrix_Softmax_Ref(layout_C.capacity(extent_C)); + cutlass::TensorView view_Softmax_Ref(matrix_Softmax_Ref.data(), layout_C, extent_C); - double norm_reference = cutlass::reference::host::TensorNorm( - reference_D.host_view()); + // Copy computed results to host memory + std::vector matrix_D(layout_C.capacity(extent_C)); + cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + total_elements_D_per_batch * batch_idx, matrix_D.size()); - double rel_error = norm_diff / norm_reference; + std::vector matrix_Softmax(layout_C.capacity(extent_C)); + cutlass::device_memory::copy_to_host(matrix_Softmax.data(), block_Softmax.get() + total_elements_D_per_batch * batch_idx, matrix_Softmax.size()); - if (rel_error > kThreshold) { - std::cerr << "\n\nTensor D Relative error: " << rel_error << std::endl; - } - else { - verified_D = true; + // Compute the norm + for (int m = 0; m < options.problem_size.m(); ++m) { + reference_N.at({m, 0}) = view_Ref.ref().at({m, 0}); + for (int n = 1; n < options.problem_size.n(); ++n) { + reference_N.at({m, 0}) = std::max(reference_N.at({m, 0}), ElementNorm(view_Ref.ref().at({m, n}))); + } } - } - if (!verified_N) { - verified_N = verify_tensor_N(tensor_N, reference_N); - } + // Compute softmax + for (int m = 0; m < options.problem_size.m(); ++m) { - if (!verified_Softmax) { + float sum = float(); - double norm_diff = cutlass::reference::host::TensorNormDiff( - tensor_Softmax.host_view(), - reference_Softmax.host_view()); + for (int n = 0; n < options.problem_size.n(); ++n) { + sum += std::exp( float(view_Ref.ref().at({m, n})) - float(reference_N.at({m, 0})) ); + } - double norm_reference = cutlass::reference::host::TensorNorm( - reference_Softmax.host_view()); + float inv_sum = float(1.0f / sum); - double rel_error = norm_diff / norm_reference; + for (int n = 0; n < options.problem_size.n(); ++n) { - if (rel_error > kThreshold) { - std::cerr << 
"\n\nSoftmax Relative error: " << rel_error << std::endl; + view_Softmax_Ref.ref().at({m, n}) = ElementSoftmax( + std::exp( float(view_Ref.ref().at({m, n})) - float(reference_N.at({m, 0})) ) * inv_sum + ); + } } - else { - verified_Softmax = true; - } - } - if (!verified_D || !verified_N || !verified_Softmax) { + // Verification checks - set any of these to 'true' to override the verification checks. + bool verified_D = false; + bool verified_Softmax = false; - std::cerr << "Verification check failed for tensor Softmax" << std::endl; - - emit_results(); - - // Summarize which checks failed + // Verify softmax output if (!verified_D) { - std::cerr << "Verification of D tensor failed\n"; + verified_D = verify_tensor(matrix_D, matrix_D_Ref); } - if (!verified_N) { - std::cerr << "Verification of N tensor failed\n"; + if (!verified_Softmax) { + verified_Softmax = verify_tensor(matrix_Softmax, matrix_Softmax_Ref); } - if (!verified_Softmax) { - std::cerr << "Verification of Softmax tensor failed\n"; + if (!verified_D || !verified_Softmax) { + + std::cerr << "Verification check failed for tensor Softmax at batch " << batch_idx << "\n"; + + // Summarize which checks failed + if (!verified_D) { + std::cerr << "Verification of D tensor failed\n"; + } + + if (!verified_Softmax) { + std::cerr << "Verification of Softmax tensor failed\n"; + } + + return false; } - return false; } return true; @@ -635,16 +659,21 @@ struct Testbed { } int64_t flops = int64_t(options.problem_size.m()) * options.problem_size.n() * options.problem_size.k() * 2; - int64_t bytes = (sizeof(ElementD) * 2 + sizeof(ElementSoftmax)) * options.problem_size.m() * options.problem_size.n(); + int64_t bytes = cutlass::bits_to_bytes( + (cutlass::sizeof_bits::value * 2 + cutlass::sizeof_bits::value) * + options.problem_size.m() * options.problem_size.n()); + + double gflops_per_second = double(flops) * kIterations * options.batch_count / double(elapsed_ms / 1000.0f) / double(1.0e9); + double gbytes_per_second = double(bytes) * kIterations * options.batch_count / double(elapsed_ms / 1000.0f) / double(1 << 30); - double gflops_per_second = double(flops) * kIterations / double(elapsed_ms / 1000.0f) / double(1.0e9); - double gbytes_per_second = double(bytes) * kIterations / double(elapsed_ms / 1000.0f) / double(1 << 30); + double elapsed_ms_per_iter = double(elapsed_ms) / kIterations; std::cout << " Problem: " << options.problem_size.m() << "-by-" << options.problem_size.n() << "-by-" << options.problem_size.k() + << ", batch size: " << options.batch_count << std::endl; - std::cout << " Runtime: " << elapsed_ms << " ms\n" << std::endl; + std::cout << " Runtime: " << elapsed_ms_per_iter << " ms\n" << std::endl; std::cout << " GFLOPs: " << gflops_per_second << " GFLOPs" << std::endl; std::cout << "Memory bandwidth: " << gbytes_per_second << " GiB/s" << std::endl; @@ -692,6 +721,4 @@ int main(int argc, const char **argv) { return (disposition == Disposition::kPassed ? 0 : -1); } - ///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h b/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h index 814de5ae7f..43208150d6 100644 --- a/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h +++ b/examples/35_gemm_softmax/gemm_with_epilogue_visitor.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -29,7 +29,8 @@ * **************************************************************************************************/ /*! \file - \brief GEMM kernel to support the 'epilogue visitor' model for fusion. + \brief GEMM kernel to support the epilogue visitor model + for customized softmax partial reduction epilogue fusion. This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once its usage has been stabilized. For now, it is included in this example to demonstrate @@ -78,6 +79,7 @@ struct GemmWithEpilogueVisitor { using ElementC = typename EpilogueVisitor::ElementOutput; using LayoutC = typename Epilogue::Layout; + using TensorRefC = TensorRef; static ComplexTransform const kTransformA = Mma::kTransformA; static ComplexTransform const kTransformB = Mma::kTransformB; @@ -89,6 +91,9 @@ struct GemmWithEpilogueVisitor { using InstructionShape = typename Mma::Policy::Operator::InstructionShape; using ArchTag = typename Mma::ArchTag; + using ElementNorm = typename EpilogueVisitor::ElementNorm; + using ElementSum = typename EpilogueVisitor::ElementSum; + static int const kStages = Mma::kStages; static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; @@ -121,6 +126,11 @@ struct GemmWithEpilogueVisitor { TensorRefA ref_A; TensorRefB ref_B; + TensorRefC ref_C; + TensorRefC ref_D; + + ElementNorm *ptr_Max; + ElementSum *ptr_Sum; int64_t batch_stride_A; int64_t batch_stride_B; @@ -144,6 +154,10 @@ struct GemmWithEpilogueVisitor { int batch_count_, TensorRefA ref_A_, TensorRefB ref_B_, + TensorRefC ref_C_, + TensorRefC ref_D_, + ElementNorm *ptr_Max_, + ElementSum *ptr_Sum_, int64_t batch_stride_A_, int64_t batch_stride_B_, typename EpilogueVisitor::Arguments epilogue_visitor_ @@ -153,6 +167,10 @@ struct GemmWithEpilogueVisitor { batch_count(batch_count_), ref_A(ref_A_), ref_B(ref_B_), + ref_C(ref_C_), + ref_D(ref_D_), + ptr_Max(ptr_Max_), + ptr_Sum(ptr_Sum_), batch_stride_A(batch_stride_A_), batch_stride_B(batch_stride_B_), epilogue_visitor(epilogue_visitor_) @@ -174,6 +192,8 @@ struct GemmWithEpilogueVisitor { typename Mma::IteratorA::Params params_A; typename Mma::IteratorB::Params params_B; + typename EpilogueVisitor::OutputTileIterator::Params params_C; + typename EpilogueVisitor::OutputTileIterator::Params params_D; GemmUniversalMode mode; int batch_count; @@ -181,6 +201,11 @@ struct GemmWithEpilogueVisitor { void * ptr_A; void * ptr_B; + ElementC * ptr_C; + ElementC * ptr_D; + + ElementNorm * ptr_Max; + ElementSum * ptr_Sum; int64_t batch_stride_A; int64_t batch_stride_B; @@ -196,11 +221,17 @@ struct GemmWithEpilogueVisitor { swizzle_log_tile(0), params_A(0), params_B(0), + params_C(0), + params_D(0), batch_count(0), gemm_k_size(0), mode(cutlass::gemm::GemmUniversalMode::kGemm), ptr_A(nullptr), ptr_B(nullptr), + ptr_C(nullptr), + ptr_D(nullptr), + ptr_Max(nullptr), + ptr_Sum(nullptr), batch_stride_A(0), batch_stride_B(0) { } @@ -213,11 +244,17 @@ struct GemmWithEpilogueVisitor { swizzle_log_tile(0), params_A(args.ref_A.layout()), params_B(args.ref_B.layout()), + params_C(args.ref_C.layout()), + params_D(args.ref_D.layout()), mode(args.mode), batch_count(args.batch_count), gemm_k_size(args.problem_size.k()), ptr_A(args.ref_A.data()), ptr_B(args.ref_B.data()), + 
ptr_C(args.ref_C.data()), + ptr_D(args.ref_D.data()), + ptr_Max(args.ptr_Max), + ptr_Sum(args.ptr_Sum), batch_stride_A(args.batch_stride_A), batch_stride_B(args.batch_stride_B), epilogue_visitor(args.epilogue_visitor) @@ -330,12 +367,6 @@ struct GemmWithEpilogueVisitor { return can_implement(args.problem_size); } - static size_t get_extra_workspace_size(Arguments const &args, - cutlass::gemm::GemmCoord const &grid_tiled_shape) { - - return 0; - } - #define SPLIT_K_ENABLED 1 /// Executes one GEMM @@ -467,7 +498,14 @@ struct GemmWithEpilogueVisitor { thread_idx, warp_idx, lane_idx, - threadblock_offset); + params.params_C, + params.params_D, + params.ptr_C, + params.ptr_D, + params.ptr_Max, + params.ptr_Sum, + threadblock_offset, + blockIdx.y *params.problem_size.m() ); if (params.mode == GemmUniversalMode::kGemm) { // Indicate which position in a serial reduction the output operator is currently updating diff --git a/examples/35_gemm_softmax/gemm_with_softmax.h b/examples/35_gemm_softmax/gemm_with_softmax.h index 213f8c5a44..748905d9f8 100644 --- a/examples/35_gemm_softmax/gemm_with_softmax.h +++ b/examples/35_gemm_softmax/gemm_with_softmax.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -49,10 +49,12 @@ #include "cutlass/gemm/kernel/default_gemm.h" #include "cutlass/gemm/kernel/default_gemm_complex.h" #include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/epilogue/threadblock/epilogue_visitor_with_softmax.h" +#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h" +#include "cutlass/reduction/kernel/reduce_softmax_final.h" ///////////////////////////////////////////////////////////////////////////////////////////////// -#include "epilogue_with_visitor.h" #include "gemm_with_epilogue_visitor.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -77,7 +79,7 @@ template < typename ElementSoft_, typename ElementSoftmaxCompute_, int Alignment, - typename Shape_ = MatrixShape<4, 16> + typename ApplyShape_ = MatrixShape<1, 1024> > class ApplySoftmax { public: @@ -89,7 +91,7 @@ class ApplySoftmax { using ElementSoftmaxCompute = ElementSoftmaxCompute_; static int const kAlignment = Alignment; - using Shape = Shape_; + using ApplyShape = ApplyShape_; using Layout = cutlass::layout::RowMajor; @@ -200,7 +202,7 @@ class ApplySoftmax { using AccessTypeD = AlignedArray; int block_batch = blockIdx.z; - int block_m = blockIdx.x * Shape::kRow; + int block_m = blockIdx.x * ApplyShape::kRow; int block_n = 0; int thread_m = threadIdx.y; @@ -209,6 +211,9 @@ class ApplySoftmax { int idx_m = block_m + thread_m; int idx_n = block_n + thread_n; + int batch_offset_norm = block_batch * params.args.batch_stride_N; + int batch_offset_sum = block_batch * params.args.batch_stride_S; + // Kill off thread if it is outside the row boundary if (params.args.extent.row() <= idx_m) { return; @@ -251,8 +256,8 @@ class ApplySoftmax { params.args.batch_stride_Soft * block_batch + params.args.ref_Soft.layout()({idx_m, idx_n})); - ElementSum inv_sum = (params.args.ref_S.data())[block_m]; - ElementNorm norm = (params.args.ref_N.data())[block_m]; + ElementSum inv_sum = (params.args.ref_S.data())[idx_m + 
batch_offset_sum]; + ElementNorm norm = (params.args.ref_N.data())[idx_m + batch_offset_norm]; // // Loop @@ -261,10 +266,9 @@ class ApplySoftmax { for ( int idx = 0; idx < params.args.extent.column(); - idx += Shape::kColumn * kAlignment) { + idx += ApplyShape::kColumn * kAlignment) { if (idx_n < params.args.extent.column()) { - AccessTypeD fetch; arch::global_load(fetch, access_d, true); @@ -274,561 +278,10 @@ class ApplySoftmax { arch::global_store(soft, access_soft, true); } - access_d += Shape::kColumn; - access_soft += Shape::kColumn; - idx_n += Shape::kColumn * kAlignment; - } - } -}; - -template < - typename ElementNorm_, - typename ElementSum_, - typename ElementSoftmaxCompute_, - typename ThreadblockShape_ -> -class ApplyFinalReduction { -public: - - using ElementNorm = ElementNorm_; - using ElementSum = ElementSum_; - using ElementSoftmaxCompute = ElementSoftmaxCompute_; - using ThreadblockShape = ThreadblockShape_; - - using Layout = cutlass::layout::RowMajor; - - using TensorRefN = TensorRef; - using TensorRefSum = TensorRef; - - // - // Arguments - // - - struct Arguments { - - MatrixCoord extent; ///< Extent of D and Softmax matrices - int batch_count; ///< Batch count - TensorRefN ref_N; ///< Norm tensor (input / output) - TensorRefSum ref_Sum; ///< Sum tensor (input / output) - int64_t batch_stride_N; ///< Batch stride for N tensor - int64_t batch_stride_Sum; ///< Batch stride for softmax tensor - - // - // Methods - // - Arguments(): - batch_count(1), - batch_stride_N(0), - batch_stride_Sum(0) - { } - - Arguments( - MatrixCoord extent_, ///< Extent of D and Softmax matrices - int batch_count_, ///< Batch count - TensorRefN ref_N_, ///< Output parameter for N - TensorRefSum ref_Sum_ , ///< Sum - int64_t batch_stride_N_ = 0, - int64_t batch_stride_Sum_ = 0 - ): - extent(extent_), - batch_count(batch_count_), - ref_N(ref_N_), - ref_Sum(ref_Sum_), - batch_stride_N(batch_stride_N_), - batch_stride_Sum(batch_stride_Sum_) - { - - } - }; - - struct SharedStorage { - - - }; - - // - // Params struct - // - - struct Params { - Arguments args; - - // - // Methods - // - Params() { } - - Params(Arguments const &args_): args(args_) { } - }; - -private: - -public: - - CUTLASS_DEVICE - ApplyFinalReduction() { } - - CUTLASS_DEVICE - void operator()(Params const ¶ms, SharedStorage &shared_storage) { - - apply(params, shared_storage); - } - -private: - - /// Partial reduction - CUTLASS_DEVICE - void apply(Params const ¶ms, SharedStorage &shared_storage) { - - int threadblock_num = (params.args.extent.column() + ThreadblockShape::kN - 1) / ThreadblockShape::kN; - - int block_batch = blockIdx.z; - - int block_n = blockIdx.x * blockDim.x; - - int thread_n = threadIdx.x; - - int idx_n = block_n + thread_n; - - if (idx_n >= params.args.extent.row()) { - return; - } - - - using ConvertSumOutput = cutlass::NumericConverter; - using ConvertNormOutput = cutlass::NumericConverter; - - using ConvertSum = cutlass::NumericConverter; - using ConvertNorm = cutlass::NumericConverter; - - ConvertSum convert_sum; - ConvertNorm convert_norm; - - ConvertSumOutput convert_sum_output; - ConvertNormOutput convert_norm_output; - - ElementNorm *access_n = params.args.ref_N.data() + params.args.batch_stride_N * block_batch + idx_n; - ElementSum *access_s = params.args.ref_Sum.data() + params.args.batch_stride_Sum * block_batch + idx_n; - - ElementNorm *access_n_bak = access_n; - ElementSum *access_s_bak = access_s; - - uint32_t float_max_bits = 0xff7fffff; - float min_float = reinterpret_cast(float_max_bits); - - 
ElementSoftmaxCompute max_val = ElementSoftmaxCompute(min_float); - ElementSoftmaxCompute sum_val = ElementSoftmaxCompute(0); - ElementNorm fetch_n; - ElementSum fetch_s; - - CUTLASS_PRAGMA_UNROLL - for (int idx_m = 0; idx_m < threadblock_num; idx_m++) { - arch::global_load(fetch_n, access_n, true); - max_val = fast_max(max_val, convert_norm(fetch_n)); - access_n += params.args.extent.row(); - } - - access_n = access_n_bak; - - CUTLASS_PRAGMA_UNROLL - for (int idx_m = 0; idx_m < threadblock_num; idx_m++) { - arch::global_load(fetch_n, access_n, true); - arch::global_load(fetch_s, access_s, true); - sum_val += convert_sum(fetch_s) * fast_exp(convert_norm(fetch_n) - max_val); - access_n += params.args.extent.row(); - access_s += params.args.extent.row(); - } - - ElementSoftmaxCompute inv_sum = cutlass::constants::one() / sum_val; - - access_n = access_n_bak; - access_s = access_s_bak; - - access_n[0] = convert_norm_output(max_val); - access_s[0] = convert_sum_output(inv_sum); - } -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// - -template < - typename ThreadblockShape_, - int ThreadCount, - typename OutputTileIterator_, - typename ElementAccumulator_, - typename ElementNorm_, - typename ElementSum_, - typename ElementSoftmaxCompute_, - typename ElementwiseFunctor_ -> -class EpilogueVisitorBiasMax { -public: - - using ThreadblockShape = ThreadblockShape_; - static int const kThreadCount = ThreadCount; - - using OutputTileIterator = OutputTileIterator_; - using ElementwiseFunctor = ElementwiseFunctor_; - - static int const kIterations = OutputTileIterator::kIterations; - static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; - - using ElementOutput = typename OutputTileIterator::Element; - using LayoutOutput = cutlass::layout::RowMajor; - using ElementAccumulator = ElementAccumulator_; - - using ElementNorm = ElementNorm_; - using ElementSum = ElementSum_; - using ElementSoftmaxCompute = ElementSoftmaxCompute_; - - using AccumulatorFragment = Array; - using SoftmaxFragment = Array; - using OutputVector = Array; - using TensorRefD = TensorRef; - - /// Argument structure - struct Arguments { - - typename ElementwiseFunctor::Params elementwise; - TensorRefD ref_C; - TensorRefD ref_D; - ElementNorm *ptr_Max; - ElementSum *ptr_Sum; - int64_t batch_stride_C; - int64_t batch_stride_D; - int64_t batch_stride_Max; - int64_t batch_stride_Sum; - - // - // Methods - // - Arguments(): - ptr_Max(nullptr), - ptr_Sum(nullptr), - batch_stride_C(0), - batch_stride_D(0), - batch_stride_Max(0), - batch_stride_Sum(0) - { - - } - - Arguments( - typename ElementwiseFunctor::Params elementwise_, - TensorRefD ref_C_, - TensorRefD ref_D_, - ElementNorm *ptr_Max_, - ElementSum *ptr_Sum_, - int64_t batch_stride_C_, - int64_t batch_stride_D_, - int64_t batch_stride_Max_, - int64_t batch_stride_Sum_ - ): - elementwise(elementwise_), - ref_C(ref_C_), - ref_D(ref_D_), - ptr_Max(ptr_Max_), - ptr_Sum(ptr_Sum_), - batch_stride_C(batch_stride_C_), - batch_stride_D(batch_stride_D_), - batch_stride_Max(batch_stride_Max_), - batch_stride_Sum(batch_stride_Sum_) - { - - } - }; - - struct Params { - - typename ElementwiseFunctor::Params elementwise; - typename OutputTileIterator::Params params_C; - typename OutputTileIterator::Params params_D; - typename OutputTileIterator::Element *ptr_C; - typename OutputTileIterator::Element *ptr_D; - ElementNorm *ptr_Max; - ElementSum *ptr_Sum; - int64_t batch_stride_C; - int64_t batch_stride_D; - int64_t 
batch_stride_Max; - int64_t batch_stride_Sum; - - // - // Methods - // - CUTLASS_HOST_DEVICE - Params(): - ptr_D(nullptr), - ptr_Max(nullptr), - ptr_Sum(nullptr) - { - - } - - CUTLASS_HOST_DEVICE - Params(Arguments const &args): - elementwise(args.elementwise), - params_C(args.ref_C.layout()), - params_D(args.ref_D.layout()), - ptr_C(args.ref_C.data()), - ptr_D(args.ref_D.data()), - ptr_Max(args.ptr_Max), - ptr_Sum(args.ptr_Sum), - batch_stride_C(args.batch_stride_C), - batch_stride_D(args.batch_stride_D), - batch_stride_Max(args.batch_stride_Max), - batch_stride_Sum(args.batch_stride_Sum) - { - - } - }; - - /// Shared storage - struct SharedStorage { - - }; - -private: - - Params const & params_; - SharedStorage & shared_storage_; - MatrixCoord extent_; - ElementwiseFunctor elementwise_; - - OutputTileIterator iterator_C_; - OutputTileIterator iterator_D_; - typename OutputTileIterator::Fragment fragment_C_; - typename OutputTileIterator::Fragment fragment_D_; - - ElementAccumulator alpha_; - ElementAccumulator beta_; - - ElementSoftmaxCompute accum_max_; - int threadblock_row_; - -public: - - CUTLASS_DEVICE - EpilogueVisitorBiasMax( - Params const ¶ms, ///< Parameters routed to the epilogue - SharedStorage &shared_storage, ///< Shared storage needed by the functors here - MatrixCoord const &problem_size, ///< Problem size of the output - int thread_idx, ///< Thread index within the threadblock - int warp_idx, ///< Warp index within the threadblock - int lane_idx, ///< Lane index within the warp - MatrixCoord const &threadblock_offset = MatrixCoord(0, 0) - ): - params_(params), - shared_storage_(shared_storage), - extent_(problem_size), - elementwise_(params.elementwise), - iterator_C_(params.params_C, params.ptr_C, problem_size, thread_idx, threadblock_offset), - iterator_D_(params.params_D, params.ptr_D, problem_size, thread_idx, threadblock_offset), - threadblock_row_(threadblock_offset.row()) - { - alpha_ = (params.elementwise.alpha_ptr ? *params.elementwise.alpha_ptr : params.elementwise.alpha); - beta_ = (params.elementwise.beta_ptr ? 
*params.elementwise.beta_ptr : params.elementwise.beta); - - if (beta_ == ElementAccumulator()) { - iterator_C_.clear_mask(); - } - } - - /// Helper to indicate split-K behavior - CUTLASS_DEVICE - void set_k_partition( - int split_k_index, ///< Index of this threadblock within split-K partitioned scheme - int split_k_slices) { ///< Total number of split-K slices - - } - - /// Called to set the batch index - CUTLASS_DEVICE - void set_batch_index(int batch_idx) { - iterator_C_.add_pointer_offset(batch_idx * params_.batch_stride_C); - iterator_D_.add_pointer_offset(batch_idx * params_.batch_stride_D); - } - - /// Called at the start of the epilogue just before iterating over accumulator slices - CUTLASS_DEVICE - void begin_epilogue() { - - } - - /// Called at the start of one step before starting accumulator exchange - CUTLASS_DEVICE - void begin_step(int step_idx) { - fragment_D_.clear(); - fragment_C_.clear(); - - if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { - iterator_C_.load(fragment_C_); - ++iterator_C_; - } - - } - - /// Called at the start of a row - CUTLASS_DEVICE - void begin_row(int row_idx) { - - } - - /// Called after accumulators have been exchanged for each accumulator vector - CUTLASS_DEVICE - void visit( - int row_idx, - int column_idx, - int frag_idx, - AccumulatorFragment const &accum) { - - using Mul = cutlass::multiplies; - using Minus = cutlass::minus; - using Exp = cutlass::fast_exp_op; - - Minus minus; - Exp exponential; - - SoftmaxFragment result; - - using ConvertSumOutput = cutlass::NumericConverter; - using ConvertNormOutput = cutlass::NumericConverter; - - ConvertSumOutput convert_sum_output; - ConvertNormOutput convert_norm_output; - - NumericArrayConverter source_converter; - OutputVector &source_vector = reinterpret_cast(&fragment_C_)[frag_idx]; - - if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { - result = source_converter(elementwise_(accum)); - }else{ - result = source_converter(elementwise_(accum, source_vector)); - } - - MatrixCoord thread_offset = - iterator_D_.thread_start() + - OutputTileIterator::ThreadMap::iteration_offset(frag_idx); - - int thread_in_row = OutputTileIterator::ThreadMap::Detail::RowArrangement::Detail::kShapeWidth; - int half_thread_in_row = (thread_in_row >> 1); - - bool column_guard = (thread_offset.column() < extent_.column()); - - // Compute the maximum within one row - if (!column_idx) { - // This is the first fragment in a new row - if (column_guard) { - accum_max_ = maximum_accumulator_(result); - } - } - else { - // This is an additional fragment in the same row - if (column_guard) { - accum_max_ = maximum_accumulator_(result, accum_max_); - } - } - - CUTLASS_PRAGMA_UNROLL - for (int i = half_thread_in_row; i > 0; i >>= 1) { - ElementSoftmaxCompute tmp = __shfl_xor_sync(0xFFFFFFFF, accum_max_, i); - accum_max_ = fast_max(accum_max_, tmp); - } - - SoftmaxFragment sum_frag = exponential(minus(result, accum_max_)); - - ElementSoftmaxCompute reduction_sum = sum_accumulator_(sum_frag); - - CUTLASS_PRAGMA_UNROLL - for (int i = half_thread_in_row; i > 0; i >>= 1) { - ElementSoftmaxCompute tmp = __shfl_xor_sync(0xFFFFFFFF, reduction_sum, i); - reduction_sum += tmp; - } - - bool is_write_thread = (thread_offset.row() < extent_.row() && (threadIdx.x % thread_in_row) == 0); - ElementNorm *curr_ptr_max = params_.ptr_Max + thread_offset.row() + blockIdx.y * extent_.row(); - ElementSum *curr_ptr_sum = params_.ptr_Sum + thread_offset.row() + blockIdx.y * 
extent_.row(); - - arch::global_store( - convert_norm_output(accum_max_), - (void *)curr_ptr_max, - is_write_thread); - - arch::global_store( - convert_sum_output(reduction_sum), - (void *)curr_ptr_sum, - is_write_thread); - - clear_accum_max_(); - - // Convert to the output - NumericArrayConverter output_converter; - OutputVector &output = reinterpret_cast(&fragment_D_)[frag_idx]; - output = output_converter(result); - } - - /// Called at the start of a row - CUTLASS_DEVICE - void end_row(int row_idx) { - - } - - /// Called after all accumulator elements have been visited - CUTLASS_DEVICE - void end_step(int step_idx) { - - iterator_D_.store(fragment_D_); - ++iterator_D_; - } - - /// Called after all steps have been completed - CUTLASS_DEVICE - void end_epilogue() { - - } - -private: - - CUTLASS_DEVICE - void clear_accum_max_() { - - uint32_t float_max_bits = 0xff7fffff; // -FLT_MAX - float min_float = reinterpret_cast(float_max_bits); - accum_max_ = ElementSoftmaxCompute(min_float); - } - - CUTLASS_DEVICE - ElementSoftmaxCompute sum_accumulator_(SoftmaxFragment const &accum) { - ElementSoftmaxCompute sum_ = ElementSoftmaxCompute(0); - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < SoftmaxFragment::kElements; ++i) { - sum_ += ElementSoftmaxCompute(accum[i]); - } - - return sum_; - } - - CUTLASS_DEVICE - ElementSoftmaxCompute maximum_accumulator_(SoftmaxFragment const &accum) { - ElementSoftmaxCompute max_ = accum[0]; - - CUTLASS_PRAGMA_UNROLL - for (int i = 1; i < SoftmaxFragment::kElements; ++i) { - max_ = fast_max(max_, ElementSoftmaxCompute(accum[i])); - } - - return max_; - } - - CUTLASS_DEVICE - ElementSoftmaxCompute maximum_accumulator_(SoftmaxFragment const &accum, ElementSoftmaxCompute max_) { - - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < SoftmaxFragment::kElements; ++i) { - max_ = fast_max(max_, ElementSoftmaxCompute(accum[i])); + access_d += ApplyShape::kColumn; + access_soft += ApplyShape::kColumn; + idx_n += ApplyShape::kColumn * kAlignment; } - - return max_; } }; @@ -846,10 +299,19 @@ template < typename LayoutB_, typename ElementC_, typename ElementCompute_, + typename OperatorClass_, + typename ArchTag_, + typename ThreadblockShape_, + typename WarpShape_, + typename InstructionShape_, typename EpilogueFunctorOp_, + int kStages_, + typename ApplyShape_ = MatrixShape<1, 1024>, + int AlignmentA_ = 128 / cutlass::sizeof_bits::value, + int AlignmentB_ = 128 / cutlass::sizeof_bits::value, + int AlignmentSoftmax_ = 128 / cutlass::sizeof_bits::value, typename ElementNorm_ = float, typename ElementSum_ = float, - int Alignment = 128 / cutlass::sizeof_bits::value, typename ElementSoftmax_ = ElementC_ > class GemmSoftmax { @@ -872,11 +334,11 @@ class GemmSoftmax { using LayoutA = LayoutA_; using LayoutB = LayoutB_; - static int const kAlignment = Alignment; - using EpilogueFunctorOp = EpilogueFunctorOp_; using ElementNorm = ElementNorm_; + using ApplyShape = ApplyShape_; + // These are mandatory layouts. 
using LayoutC = cutlass::layout::RowMajor; using LayoutN = cutlass::layout::RowMajor; @@ -890,13 +352,17 @@ class GemmSoftmax { using TensorRefSum = TensorRef; using TensorRefSoft = TensorRef; - using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; - using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; - using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; - using OperatorClass = cutlass::arch::OpClassTensorOp; - using ArchTag = cutlass::arch::Sm80; - static int const kStages = 3; + static int const kStages = kStages_; + static int const AlignmentA = AlignmentA_; + static int const AlignmentB = AlignmentB_; + static int const AlignmentSoftmax = AlignmentSoftmax_; using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle; @@ -906,10 +372,10 @@ class GemmSoftmax { using DefaultGemmKernel = typename cutlass::gemm::kernel::DefaultGemm< ElementA, LayoutA, - kAlignment, + AlignmentA, ElementB, LayoutB, - kAlignment, + AlignmentB, ElementC, LayoutC, ElementCompute, @@ -930,7 +396,7 @@ class GemmSoftmax { /////////////////////////////////////////////////////////////////////////////////////////////// // Epilogue visitor - using EpilogueVisitor = kernel::EpilogueVisitorBiasMax< + using EpilogueVisitor = typename cutlass::epilogue::threadblock::EpilogueVisitorSoftmax< ThreadblockShape, DefaultGemmKernel::kThreadCount, typename DefaultGemmKernel::Epilogue::OutputTileIterator, @@ -961,13 +427,11 @@ class GemmSoftmax { ElementSum, ElementSoft, ElementSoftmaxCompute, - kAlignment, - MatrixShape< - 1, 1024 - > + AlignmentSoftmax, + ApplyShape >; - using ApplyFinalReductionKernel = kernel::ApplyFinalReduction< + using ApplyFinalReductionKernel = cutlass::reduction::kernel::ApplySoftmaxFinalReduction< ElementNorm, ElementSum, ElementSoftmaxCompute, @@ -983,6 +447,7 @@ class GemmSoftmax { typename SoftmaxApplyKernel::Arguments softmax; typename ApplyFinalReductionKernel::Arguments reduction; cutlass::gemm::GemmCoord extend; + // // Methods // @@ -1013,14 +478,14 @@ class GemmSoftmax { batch_count_, ref_A_, ref_B_, + ref_C_, + ref_D_, + ref_N_.data(), + ref_S_.data(), batch_stride_A_, batch_stride_B_, typename EpilogueVisitor::Arguments( linear_scaling, - ref_C_, - ref_D_, - ref_N_.data(), - ref_S_.data(), batch_stride_C_, batch_stride_D_, batch_stride_Max_, @@ -1028,10 +493,9 @@ class GemmSoftmax { ) ), reduction( - MatrixCoord(problem_size.m(), problem_size.n()), - batch_count_, - ref_N_, - ref_S_, + problem_size, + ref_N_.data(), + ref_S_.data(), batch_stride_Max_, batch_stride_Sum_ ), @@ -1114,9 +578,21 @@ class GemmSoftmax { int gemm_smem_size = int(sizeof(typename GemmKernel::SharedStorage)); + cudaError_t result; + + if (gemm_smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute(cutlass::Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + gemm_smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + cutlass::Kernel<<>>(params_.gemm); - cudaError_t result = cudaGetLastError(); + result = cudaGetLastError(); if (result != cudaSuccess) { return cutlass::Status::kErrorInternal; @@ -1127,42 +603,38 @@ class GemmSoftmax { // Launch the ApplyFinalReductionKernel // - int threadblock_num_in_column = (params_.extend.column() + ThreadblockShape::kN - 1) / ThreadblockShape::kN; - - if (threadblock_num_in_column > 1) { - int 
thread_per_block = 128; - int block_per_row = (params_.extend.row() + thread_per_block - 1) / thread_per_block; - if (block_per_row < 4) { - thread_per_block = 32; - block_per_row = (params_.extend.row() + thread_per_block - 1) / thread_per_block; - } + int thread_per_block = 128; + int block_per_row = (params_.extend.row() + thread_per_block - 1) / thread_per_block; + if (block_per_row < 4) { + thread_per_block = 32; + block_per_row = (params_.extend.row() + thread_per_block - 1) / thread_per_block; + } - dim3 final_reduction_grid(block_per_row); - dim3 final_reduction_block(thread_per_block); + dim3 final_reduction_grid(block_per_row, 1, params_.softmax.args.batch_count); + dim3 final_reduction_block(thread_per_block); - Kernel<<< - final_reduction_grid, final_reduction_block, sizeof(typename ApplyFinalReductionKernel::SharedStorage), stream - >>>(params_.reduction); + Kernel<<< + final_reduction_grid, final_reduction_block, sizeof(typename ApplyFinalReductionKernel::SharedStorage), stream + >>>(params_.reduction); - result = cudaGetLastError(); + result = cudaGetLastError(); - if (result != cudaSuccess) { - return cutlass::Status::kErrorInternal; - } + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; } // // Launch the SoftmaxApplyKernel // - dim3 apply_block(SoftmaxApplyKernel::Shape::kColumn, SoftmaxApplyKernel::Shape::kRow); + dim3 apply_block(SoftmaxApplyKernel::ApplyShape::kColumn, SoftmaxApplyKernel::ApplyShape::kRow); - int cta_rows = SoftmaxApplyKernel::Shape::kRow; - int cta_columns = SoftmaxApplyKernel::Shape::kColumn * SoftmaxApplyKernel::kAlignment; + int threadblock_rows = SoftmaxApplyKernel::ApplyShape::kRow; + int threadblock_columns = SoftmaxApplyKernel::ApplyShape::kColumn * SoftmaxApplyKernel::kAlignment; dim3 apply_grid( - (params_.softmax.args.extent.row() + cta_rows - 1) / cta_rows, - (params_.softmax.args.extent.column() + cta_columns - 1) / cta_columns, + (params_.softmax.args.extent.row() + threadblock_rows - 1) / threadblock_rows, + (params_.softmax.args.extent.column() + threadblock_columns - 1) / threadblock_columns, params_.softmax.args.batch_count); Kernel<<< diff --git a/examples/36_gather_scatter_fusion/CMakeLists.txt b/examples/36_gather_scatter_fusion/CMakeLists.txt index 28edd47868..b54ea9ff81 100644 --- a/examples/36_gather_scatter_fusion/CMakeLists.txt +++ b/examples/36_gather_scatter_fusion/CMakeLists.txt @@ -1,4 +1,4 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without diff --git a/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu b/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu index f8fbcc33c3..55852730c2 100644 --- a/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu +++ b/examples/36_gather_scatter_fusion/gather_scatter_fusion.cu @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -40,18 +40,17 @@ // for (int j = 0; j < options.index_size; ++j) { // int b_c_d_col = tensor_indices.at({j, 0}); // -// for (int k = 0; k < problem_size.k(); ++k) { +// for (int k = 0; k < options.index_size; ++k) { // tensor_d_ref.at({i, b_c_d_col}) += // alpha * tensor_a.at({i, k}) * tensor_b.at({k, b_c_d_col}); // } // } -// } // // Note that the index vector contains unique random integers with max to be N - 1 // // The gather/scatter operation works best when we can still keep the biggest // alignment. For example, when the matrix is row major, we select rows. When -// the matrix is column major, we selct columns. +// the matrix is column major, we select columns. // // Not all the combination of gather and scatter are legal. For example, if A is // row major and C/D is column major, we cannot gather A and scatter C/D at the @@ -60,11 +59,11 @@ // Also, we don't check the index value is legal and index array point is valid // for the sake of the performance. -#include -#include -#include -#include -#include +#include +#include +#include +#include +#include #include #include @@ -174,8 +173,8 @@ struct Options { /// Compute performance in GFLOP/s double gflops(double runtime_s) const { - // Number of real-valued multiply-adds - int64_t fmas = problem_size.product(); + // Number of real-valued multiply-adds + int64_t fmas = problem_size.m() * int64_t(index_size) * problem_size.k(); // Two flops per multiply-add return 2.0 * double(fmas) / double(1.0e9) / runtime_s; @@ -188,8 +187,8 @@ struct Options { // elements in input matrices. using ElementAccumulator = float; // <- data type of accumulator using ElementComputeEpilogue = ElementAccumulator; // <- data type of epilogue operations -using ElementInputA = cutlass::half_t;; // <- data type of elements in input matrix A -using ElementInputB = cutlass::half_t;; // <- data type of elements in input matrix B +using ElementInputA = cutlass::half_t; // <- data type of elements in input matrix A +using ElementInputB = cutlass::half_t; // <- data type of elements in input matrix B using ElementOutput = float; // <- data type of elements in output matrix D // The code section below describes matrix layout of input and output matrices. @@ -216,7 +215,7 @@ using ShapeMMAOp = cutlass::gemm::GemmShape<16, 8, 16>; // <- MMA Op tile M = 8 // 16, 8, 16 -> Ampere // This code section describes how threadblocks are scheduled on GPU -using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // <- ?? +using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; // Define the epilogue operation as LinearCombination. This is approximately equal to // @@ -253,11 +252,11 @@ using Gemm = cutlass::gemm::device::GemmUniversal; @@ -317,7 +316,11 @@ int run(Options &options) { // <- Fill tensor_b_indices on host with unique random integers std::vector to_fill(problem_size.n()) ; // vector with ints. 
std::iota (std::begin(to_fill), std::end(to_fill), 0); // Fill with 0, 1, ...., problem_size.n() - std::random_shuffle(to_fill.begin(), to_fill.end()); + { // std::random_shuffle was deprecated in C++14 and removed in C++17 + std::random_device make_seed; + std::mt19937 source_of_randomness(make_seed()); + std::shuffle(to_fill.begin(), to_fill.end(), source_of_randomness); + } memcpy(tensor_indices.host_data(), to_fill.data(), options.index_size * sizeof(int)); // Copy data from host to GPU @@ -346,14 +349,14 @@ int run(Options &options) { tensor_c.device_data(), // <- reference to matrix C on device tensor_d_scattered.device_data(), // <- reference to matrix D on device tensor_a.layout().capacity(problem_size.mk()), - tensor_b.layout().capacity(cutlass::make_Coord(options.index_size, problem_size.n())), + tensor_b.layout().capacity(cutlass::make_Coord(options.index_size, problem_size.k())), tensor_c.layout().capacity(problem_size.mn()), tensor_d_scattered.layout().capacity(problem_size.mn()), tensor_a.layout().stride(), tensor_b.layout().stride(), tensor_c.layout().stride(), tensor_d_scattered.layout().stride(), - nullptr, // <- pointer to index vector to gather A on device + nullptr, // <- pointer to index vector to gather A on device tensor_indices.device_data(), // <- pointer to index vector to gather B on device tensor_indices.device_data()}; // <- pointer to index vector to scatter D on device @@ -392,7 +395,7 @@ int run(Options &options) { tensor_d_ref.at({i, b_c_d_col}) += alpha * tensor_a.at({i, k}) * tensor_b.at({k, b_c_d_col}); } - + tensor_d_ref.at({i, b_c_d_col}) += (beta * tensor_c.at({i, b_c_d_col})); } } @@ -515,7 +518,7 @@ int main(int argc, const char ** argv) { cudaDeviceProp props; CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); - if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) { + if (!(props.major >= 8)) { std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80." << std::endl; notSupported = true; diff --git a/examples/37_gemm_layernorm_gemm_fusion/CMakeLists.txt b/examples/37_gemm_layernorm_gemm_fusion/CMakeLists.txt new file mode 100644 index 0000000000..334ec381a1 --- /dev/null +++ b/examples/37_gemm_layernorm_gemm_fusion/CMakeLists.txt @@ -0,0 +1,36 @@ + +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + +cutlass_example_add_executable( + 37_gemm_layernorm_gemm_fusion + gemm_layernorm.cu + ) + diff --git a/examples/37_gemm_layernorm_gemm_fusion/gemm_layernorm.cu b/examples/37_gemm_layernorm_gemm_fusion/gemm_layernorm.cu new file mode 100644 index 0000000000..b5a0a1dcb7 --- /dev/null +++ b/examples/37_gemm_layernorm_gemm_fusion/gemm_layernorm.cu @@ -0,0 +1,937 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief CUTLASS Layernorm Example. + + This workload provides a layer normalization example using a one-pass, square-sum-based + variance calculation. Specifically, we fuse the reduction operation to find + local mean and local square sum mean in the epilogue of 1st GEMM. After a light + full reduction kernel, the mean / variance values are readily calculated for element-wise + operations which are fused into the 2nd GEMM. + + As stated in https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance#Computing_shifted_data, + the square-sum based one-pass implementation may raise concerns on numerical stability issues. 
+ That being said, though this fully fused layernorm example almost perfectly hides all the memory cost to + access the intermediate matrix for layernorm computation, the numerical issue might hinder a persuasive + usage in real-world scenarios. If that is the case, a user may turn to the stand-alone CUTLASS layernorm + example in tools/util/include/cutlass/util/device_layernorm.h + + Examples: + + # Run a CUTLASS layernorm example with default setup , + # using the language of the transformer model as an example, + (Column Major output matrix, hidden dimension = 768, valid word number = 4096, intermediate_scale = 4) + $ ./examples/37_gemm_layernorm_gemm_fusion/37_gemm_layernorm_gemm_fusion + + # Run an attention example with hidden dimension = 512 + $ ./examples/37_gemm_layernorm_gemm_fusion/37_gemm_layernorm_gemm_fusion --hidden_dim=512 + +*/ + +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/device/gemm_complex.h" +#include "cutlass/epilogue/thread/scale_type.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/host/gemm_complex.h" +#include "cutlass/util/reference/host/tensor_reduce.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/error_metrics.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/fast_math.h" +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "gemm_with_layernorm.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +enum class Disposition { + kPassed, + kIncorrect, + kNotVerified +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +template +struct Options { + + using LayoutOutput = LayoutOutput_; + + static bool const kIsColumnMajorOutput = cutlass::platform::is_same::value; + + bool help; + cutlass::gemm::GemmCoord problem_size0; + cutlass::gemm::GemmCoord problem_size1; + int hidden_dim; + int valid_word_num; + int intermediate_scale; + int iterations; + unsigned seed; + float alpha; + float beta; + bool verification_enabled; + double tolerance; + + Options(): + help(false), + iterations(20), + seed(2022), + hidden_dim(768), + valid_word_num(4096), + intermediate_scale(4), + alpha(1), + beta(0), + verification_enabled(true), + tolerance(0.01), + problem_size1(problem_size0.m() * 4, problem_size0.n(), problem_size0.m()) + { } + + bool valid() { + + return true; + } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + cmd.get_cmd_line_argument("hidden_dim", hidden_dim, 768); + cmd.get_cmd_line_argument("valid_word_num", valid_word_num, 4096); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("verify", verification_enabled); + cmd.get_cmd_line_argument("seed", seed); + cmd.get_cmd_line_argument("tolerance", tolerance); + + if (kIsColumnMajorOutput) { + // column major output setup + problem_size0.m() = hidden_dim; + problem_size0.n() = valid_word_num; + 
problem_size0.k() = hidden_dim; + + problem_size1.m() = hidden_dim * intermediate_scale; + problem_size1.n() = valid_word_num; + problem_size1.k() = hidden_dim; + }else{ + // row major output setup + problem_size0.m() = valid_word_num; + problem_size0.n() = hidden_dim; + problem_size0.k() = hidden_dim; + + problem_size1.m() = valid_word_num; + problem_size1.n() = hidden_dim * intermediate_scale; + problem_size1.k() = hidden_dim; + } + + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "37_gemm_layernorm_gemm_fusion example\n\n" + << " This example uses the CUTLASS Library to compute GEMM + Layernorm for arbitrary problem sizes.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --hidden_dim= Hidden dimension\n" + << " --valid_word_num= Valid word number\n" + << " --seed= Random number seed (1*)\n\n" + << " --iterations= Number of profiling iterations to perform (0 to disable profiling).\n\n" + << " --verify= If true, performs reference calculation.\n\n" + << " --tolerance Error tolerance\n" + ; + + out << "\n\nExamples:\n\n" + << "$ ./examples/37_gemm_layernorm_gemm_fusion/37_gemm_layernorm_gemm_fusion \\\n" + << " --hidden_dim=768 --valid_word_num=1024 \n\n"; + + return out; + } + + /// Returns true if the environment and Toolkit support this + bool supported(bool verbose = true) const { + + // Ampere Tensor Core operations exposed with mma.sync and ldmatrix are first available + // in CUDA 11.0. + // + // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. + if (!(__CUDACC_VER_MAJOR__ >= 11)) { + if (verbose) { + std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; + } + return false; + } + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + if (verbose) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + } + return false; + } + + if (!((props.major * 10 + props.minor) >= 80)) { + if (verbose) { + std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." + << std::endl; + } + return false; + } + + // + // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, + // all pointers, strides, and tensor extents must be divisible by 8 elements. + // + int const kAlignment = 8; + + if ((problem_size0.m() % kAlignment) || + (problem_size0.n() % kAlignment) || + (problem_size0.k() % kAlignment)) { + if (verbose) { + std::cerr << "Misaligned input in 1st GEMM." << std::endl; + } + // misaligned tensors for Gemm1 + return false; + } + + if ((problem_size1.m() % kAlignment) || + (problem_size1.n() % kAlignment) || + (problem_size1.k() % kAlignment)) { + if (verbose) { + std::cerr << "Misaligned input in 2nd GEMM." 
<< std::endl; + } + // misaligned tensors for Gemm2 + return false; + } + + return true; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + typename LayoutOutput_> +struct Testbed { + + // + // Type definitions + // + + // User-defined data types + using ElementInputA0 = cutlass::half_t; + using ElementInputB0 = cutlass::half_t; + using ElementOutput = cutlass::half_t; + using ElementCompute = cutlass::half_t; + + using LayoutInputA0 = cutlass::layout::RowMajor; + using LayoutInputB0 = cutlass::layout::ColumnMajor; + using LayoutOutput = LayoutOutput_; + + static bool const kIsColumnMajorOutput = cutlass::platform::is_same::value; + // turn of shifted K by default + static bool const kIsShiftedVariance = false; + + /// Linear scaling operator + using EpilogueFunctorOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementCompute, + ElementCompute + >; + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + static int const kStages0 = 3; + static int const kStages1 = 4; + + using GemmLayernorm = cutlass::GemmLayernorm< + ElementInputA0, + LayoutInputA0, + ElementInputB0, + LayoutInputB0, + ElementOutput, + LayoutOutput, + ElementCompute, + EpilogueFunctorOp, + ThreadblockShape, + WarpShape, + InstructionShape, + kStages0, + kStages1, + kIsShiftedVariance + >; + + using ElementInputA1 = typename GemmLayernorm::ElementInputA1; + using ElementOutputC1 = typename GemmLayernorm::ElementOutputC1; + using ElementInputScaleBias = typename GemmLayernorm::ElementInputScaleBias; + using ElementLayernormCompute = typename GemmLayernorm::ElementLayernormCompute; + + using LayoutInputA1 = typename GemmLayernorm::LayoutInputA1; + using LayoutOutputC0 = typename GemmLayernorm::LayoutOutputC0; + using LayoutOutputC1 = typename GemmLayernorm::LayoutOutputC1; + using LayoutInputScaleBias = typename GemmLayernorm::LayoutInputScaleBias; + + // + // Data members + // + + Options const &options; + + cutlass::HostTensor tensor_A0; + cutlass::HostTensor tensor_B0; + cutlass::HostTensor tensor_C0; + cutlass::HostTensor tensor_A1; + cutlass::HostTensor tensor_C1; + + cutlass::HostTensor reference_C0; + cutlass::HostTensor reference_C1; + + cutlass::HostTensor tensor_Variance; + cutlass::HostTensor tensor_Mean; + cutlass::HostTensor tensor_Beta; + cutlass::HostTensor tensor_Gamma; + + cutlass::HostTensor reference_Mean; + cutlass::HostTensor reference_Variance; + + // shifted K tensor to better ensure the numerical stability + // According to https://en.wikipedia.org/wiki/Algorithms_for_calculating_variance + // the closer shifted K to the actual mean, the better numerical stability we'll observe + cutlass::HostTensor tensor_Shifted_K; + + // + // Methods + // + + Testbed( + Options const &options_ + ): + options(options_) + { + + tensor_A0.reset({options.problem_size0.m(), options.problem_size0.k()}); + tensor_B0.reset({options.problem_size0.k(), options.problem_size0.n()}); + + tensor_C0.reset({options.problem_size0.m(), options.problem_size0.n()}); + + tensor_A1.reset({options.problem_size1.m(), options.problem_size1.k()}); + tensor_C1.reset({options.problem_size1.m(), options.problem_size1.n()}); + + reference_C0.reset({options.problem_size0.m(), options.problem_size0.n()}); + reference_C1.reset({options.problem_size1.m(), 
options.problem_size1.n()}); + + int leading_dim_0 = kIsColumnMajorOutput ? options.problem_size0.n() : options.problem_size0.m(); + int leading_dim_1 = kIsColumnMajorOutput ? options.problem_size0.m() : options.problem_size0.n(); + + int block_num = (leading_dim_1 + GemmLayernorm::ThreadblockShape::kM - 1) / GemmLayernorm::ThreadblockShape::kM; + + tensor_Variance.reset({block_num, leading_dim_0}); + tensor_Mean.reset({block_num, leading_dim_0}); + tensor_Shifted_K.reset({1, leading_dim_0}); + + tensor_Beta.reset({1, leading_dim_1}); + tensor_Gamma.reset({1, leading_dim_1}); + + reference_Mean.reset({1, leading_dim_0}, false); + reference_Variance.reset({1, leading_dim_0}, false); + + } + + /// Run + Disposition run() { + + Disposition disposition = Disposition::kNotVerified; + + // + // Initialize the workspace + // + + initialize(); + + // + // Launch device kernel + // + cutlass::Status status = cutlass::Status::kSuccess; + + status = execute_device_kernel(); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "Device execution failed." << std::endl; + return disposition; + } + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Device synchronize failed with error " + << cudaGetErrorString(result) << std::endl; + return disposition; + } + + // + // Compute the reference + // + compute_reference(); + + // + // Verify + // + + if (options.verification_enabled) { + + bool passed = verify(); + + if (passed) { + disposition = Disposition::kPassed; + } + else { + disposition = Disposition::kIncorrect; + } + } + + // + // Profiling + // + if (options.iterations) { + profile(); + } + + return disposition; + } + + /// Random initialization + void initialize() { + + cutlass::reference::host::TensorFillRandomUniform( + tensor_A0.host_view(), + options.seed, + ElementInputA0(4), + ElementInputA0(-4), + 0 + ); + + cutlass::reference::host::TensorFillRandomUniform( + tensor_B0.host_view(), + options.seed + 1, + ElementInputB0(4), + ElementInputB0(-4), + 0 + ); + + cutlass::reference::host::TensorFillRandomUniform( + tensor_A1.host_view(), + options.seed + 2, + ElementInputA1(4), + ElementInputA1(-4), + 0 + ); + + cutlass::reference::host::TensorFillRandomUniform( + tensor_Beta.host_view(), + options.seed + 3, + ElementInputScaleBias(4), + ElementInputScaleBias(-4), + 0 + ); + + cutlass::reference::host::TensorFillRandomUniform( + tensor_Gamma.host_view(), + options.seed + 4, + ElementInputScaleBias(4), + ElementInputScaleBias(-4), + 0 + ); + + cutlass::reference::host::TensorFillRandomUniform( + tensor_Shifted_K.host_view(), + options.seed + 5, + ElementOutput(4), + ElementOutput(-5), + 0 + ); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_A1.sync_device(); + tensor_Beta.sync_device(); + tensor_Gamma.sync_device(); + + } + + + + cutlass::Status execute_device_kernel() { + + cutlass::Status status = cutlass::Status::kSuccess; + + // + // Setup arguments + // + + typename GemmLayernorm::Arguments args( + options.problem_size0, + options.problem_size1, + tensor_A0.device_ref().data(), + tensor_B0.device_ref().data(), + tensor_C0.device_ref().data(), + tensor_C0.device_ref().data(), + tensor_A1.device_ref().data(), + tensor_C1.device_ref().data(), + tensor_A0.device_ref().stride(0), + tensor_B0.device_ref().stride(0), + tensor_C0.device_ref().stride(0), + tensor_C0.device_ref().stride(0), + tensor_A1.device_ref().stride(0), + tensor_C1.device_ref().stride(0), + { + ElementCompute(options.alpha), + ElementCompute(options.beta) + }, + 
tensor_Variance.device_ref(), + tensor_Mean.device_ref(), + tensor_Gamma.device_ref(), + tensor_Beta.device_ref(), + tensor_Shifted_K.device_ref().data() + ); + + // + // Launch + // + + GemmLayernorm gemm_layernorm; + + // Initialize + status = gemm_layernorm.initialize(args); + if (status != cutlass::Status::kSuccess) { + return status; + } + + // Run + status = gemm_layernorm(); + + return status; + } + + /// Reference calculation + void compute_reference() { + + cutlass::reference::device::Gemm< + ElementInputA0, + LayoutInputA0, + ElementInputB0, + LayoutInputB0, + ElementOutput, + LayoutOutputC0, + ElementCompute, + ElementCompute + > gemm_device0; + + cutlass::reference::device::Gemm< + ElementInputA1, + LayoutInputA1, + ElementOutput, + LayoutOutputC0, + ElementOutputC1, + LayoutOutputC1, + ElementCompute, + ElementCompute + > gemm_device1; + + // Compute 1st GEMM + gemm_device0( + options.problem_size0, + ElementCompute(options.alpha), + tensor_A0.device_ref(), + tensor_B0.device_ref(), + ElementCompute(options.beta), + tensor_C0.device_ref(), + reference_C0.device_ref() + ); + + reference_C0.sync_host(); + + tensor_Mean.sync_host(); + tensor_Variance.sync_host(); + tensor_Gamma.sync_host(); + tensor_Beta.sync_host(); + tensor_Shifted_K.sync_host(); + + // Compute the sum and square sum for verification purpose + if (kIsColumnMajorOutput) { + for (int n = 0; n < options.problem_size0.n(); ++n) { + + ElementLayernormCompute sum = ElementLayernormCompute(0); + ElementLayernormCompute square_sum = ElementLayernormCompute(0); + for (int m = 0; m < options.problem_size0.m(); ++m) { + sum += ElementLayernormCompute(reference_C0.at({m, n})); + square_sum += ElementLayernormCompute(reference_C0.at({m, n})) * ElementLayernormCompute(reference_C0.at({m, n})); + } + + ElementLayernormCompute mean = sum / ElementLayernormCompute(options.problem_size0.m()); + ElementLayernormCompute square_mean = square_sum / ElementLayernormCompute(options.problem_size0.m()); + ElementLayernormCompute variance = cutlass::constants::one() / cutlass::fast_sqrt(square_mean - mean * mean + ElementLayernormCompute(1e-6) ) ; + + mean = -mean * variance; + + reference_Mean.at({0, n}) = ElementInputScaleBias(mean); + reference_Variance.at({0, n}) = ElementInputScaleBias(variance); + } + }else{ + for (int m = 0; m < options.problem_size0.m(); ++m) { + + ElementLayernormCompute sum = ElementLayernormCompute(0); + ElementLayernormCompute square_sum = ElementLayernormCompute(0); + for (int n = 0; n < options.problem_size0.n(); ++n) { + sum += ElementLayernormCompute(reference_C0.at({m, n})) ; + square_sum += ElementLayernormCompute(reference_C0.at({m, n})) * ElementLayernormCompute(reference_C0.at({m, n})) ; + } + + ElementLayernormCompute mean = sum / ElementLayernormCompute(options.problem_size0.n()); + ElementLayernormCompute square_mean = square_sum / ElementLayernormCompute(options.problem_size0.n()); + ElementLayernormCompute variance = cutlass::constants::one() / cutlass::fast_sqrt(square_mean - mean * mean + ElementLayernormCompute(1e-6)) ; + + mean = -mean * variance; + + reference_Mean.at({0, m}) = ElementInputScaleBias(mean); + reference_Variance.at({0, m}) = ElementInputScaleBias(variance); + } + } + + // Element-wise transform for OutputC0 using 1-pass layernorm algo + if (kIsColumnMajorOutput) { + for (int n = 0; n < options.problem_size0.n(); ++n) { + + ElementLayernormCompute sum = ElementLayernormCompute(0); + for (int m = 0; m < options.problem_size0.m(); ++m) { + sum += 
ElementLayernormCompute(reference_C0.at({m, n})) ; + } + + ElementInputScaleBias mean = ElementInputScaleBias(sum / ElementLayernormCompute(options.problem_size0.m())); + sum = ElementLayernormCompute(0); + for (int m = 0; m < options.problem_size0.m(); ++m) { + sum += ElementLayernormCompute(reference_C0.at({m, n}) - ElementLayernormCompute(mean)) * ElementLayernormCompute(reference_C0.at({m, n}) - ElementLayernormCompute(mean)) ; + } + + ElementLayernormCompute square_mean = sum / ElementLayernormCompute(options.problem_size0.m()); + ElementInputScaleBias variance = ElementInputScaleBias(cutlass::constants::one() + / cutlass::fast_sqrt(square_mean + ElementLayernormCompute(1e-6))) ; + + for (int m = 0; m < options.problem_size0.m(); ++m) { + reference_C0.at({m, n}) = + ElementOutput( ( (ElementInputScaleBias(reference_C0.at({m, n})) - mean) * variance ) + * tensor_Gamma.at({0, m}) + tensor_Beta.at({0, m})); + + } + + } + }else{ + + for (int m = 0; m < options.problem_size0.m(); ++m) { + + float sum = float(0); + for (int n = 0; n < options.problem_size0.n(); ++n) { + sum += float(reference_C0.at({m, n})) ; + } + + float mean = sum / float(options.problem_size0.n()); + sum = float(0); + for (int n = 0; n < options.problem_size0.n(); ++n) { + sum += float(reference_C0.at({m, n}) - mean) * float(reference_C0.at({m, n}) - mean) ; + } + + float square_mean = sum / float(options.problem_size0.n()); + float variance = cutlass::constants::one() / cutlass::fast_sqrt(square_mean + ElementLayernormCompute(1e-6)) ; + + for (int n = 0; n < options.problem_size0.n(); ++n) { + reference_C0.at({m, n}) = + ElementOutput( ( (float(reference_C0.at({m, n})) - mean) * variance ) + * float(tensor_Gamma.at({0, n})) + float(tensor_Beta.at({0, n}))); + + } + + } + + } + + + // Sync host data with device after element-wise transform + reference_C0.sync_device(); + + // Compute 2nd GEMM + gemm_device1( + options.problem_size1, + ElementCompute(options.alpha), + kIsColumnMajorOutput ? tensor_A1.device_ref() : reference_C0.device_ref(), + kIsColumnMajorOutput ? 
reference_C0.device_ref() :tensor_A1.device_ref(), + ElementCompute(options.beta), + reference_C1.device_ref(), + reference_C1.device_ref() + ); + + } + + /// Emits all tensor values + void emit_results() { + std::cout << "tensor_C1 = \n" << tensor_C1.host_view() << "\n\n"; + std::cout << "Reference C1 = \n" << reference_C1.host_view() << "\n\n"; + std::cout << "Mean = \n" << tensor_Mean.host_view() << "\n\n"; + std::cout << "rsqrt(Variance) = \n" << tensor_Variance.host_view() << "\n\n"; + std::cout << "Reference Mean = \n" << reference_Mean.host_view() << "\n\n"; + std::cout << "Reference rsqrt(Variance) = \n" << reference_Variance.host_view() << "\n\n"; + } + + template + bool verify_tensor(cutlass::HostTensor tensor, \ + cutlass::HostTensor reference, + int leading_dim0, int leading_dim1, bool is_print = false) { + float const kThreshold = float(options.tolerance); + float const kAbsThreshold = 0.5f; + float const kRelativeThreshold = 0.1f; + // Adds a constant bias to avoid being divided by '0' + float const kBias = 1e-5f; + int counter = 0; + for (int m = 0; m < leading_dim0; m++) { + for (int n = 0; n < leading_dim1; ++n) { + float diff = (float)(tensor.at({m, n}) - reference.at({m, n})); + float rel_diff = fabs(diff) / fabs(reference.at({m, n}) + kBias); + if (fabs(diff) > kAbsThreshold && rel_diff > kRelativeThreshold) { + counter++; + } + } + } + + float err_rate = float(counter) / (float(leading_dim0) * float(leading_dim1)); + return (err_rate < kThreshold); + } + + /// Verifies the reference matches + bool verify() { + + tensor_Variance.sync_host(); + tensor_Mean.sync_host(); + tensor_C1.sync_host(); + reference_C1.sync_host(); + + // Verification checks - set any of these to 'true' to override the verification checks. + bool verified_C1 = false; + bool verified_Mean = false; + bool verified_Variance = false; + + // Verify layernorm output + if (!verified_C1) { + verified_C1 = verify_tensor(tensor_C1, reference_C1, options.problem_size1.m(), options.problem_size1.n()); + } + + if (!verified_Variance) { + verified_Variance = verify_tensor(tensor_Variance, reference_Variance, 1, options.problem_size0.n()); + } + + if (!verified_Mean) { + verified_Mean = verify_tensor(tensor_Mean, reference_Mean, 1, options.problem_size0.n()); + } + + if (!verified_C1 || !verified_Mean || !verified_Variance) { + + // emit_results(); + + std::cerr << "Verification check failed for tensor Layernorm" << std::endl; + + // Summarize which checks failed + if (!verified_C1) { + std::cerr << "Verification of O tensor failed\n"; + } + + if (!verified_Mean) { + std::cerr << "Verification of Mean tensor failed\n"; + } + + if (!verified_Variance) { + std::cerr << "Verification of Variance tensor failed\n"; + } + + return false; + } + + return true; + } + + /// Profiles + bool profile() { + + // + // Profile + // + + cutlass::Status status = cutlass::Status::kSuccess; + cudaError_t result; + cudaEvent_t events[2]; + int const kIterations = options.iterations; + + for (cudaEvent_t &evt : events) { + result = cudaEventCreate(&evt); + if (result != cudaSuccess) { + std::cerr << "cudaEventCreate failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + } + + result = cudaEventRecord(events[0]); + + if (result != cudaSuccess) { + std::cerr << "cudaEventRecord() failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + for (int iter = 0; iter < kIterations; ++iter) { + + status = execute_device_kernel(); + + if (status != cutlass::Status::kSuccess) { + 
std::cerr << "Device execution failed." << std::endl; + return false; + } + } + + result = cudaEventRecord(events[1]); + + if (result != cudaSuccess) { + std::cerr << "cudaEventRecord() failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + result = cudaDeviceSynchronize(); + + if (result != cudaSuccess) { + std::cerr << "cudaDeviceSynchronize() failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + float elapsed_ms = 0; + result = cudaEventElapsedTime(&elapsed_ms, events[0], events[1]); + + float elapsed_ms_per_iter = elapsed_ms / float(kIterations); + + if (result != cudaSuccess) { + std::cerr << "cudaEventElapsedTime() failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + + for (cudaEvent_t &evt : events) { + result = cudaEventDestroy(evt); + if (result != cudaSuccess) { + std::cerr << "cudaEventDestroy() failed with error " << cudaGetErrorString(result) << std::endl; + return false; + } + } + + int64_t flops = int64_t(options.problem_size0.m()) * options.problem_size0.n() * options.problem_size0.k() * 2 \ + + int64_t(options.problem_size1.m()) * options.problem_size1.n() * options.problem_size1.k() * 2; + + double gflops_per_second = double(flops) * kIterations / double(elapsed_ms / 1000.0f) / double(1.0e9); + + std::cout << " 1st GEMM: " + << options.problem_size0.m() << "-by-" << options.problem_size0.n() << "-by-" << options.problem_size0.k() << "\n" + << " 2nd GEMM: " + << options.problem_size1.m() << "-by-" << options.problem_size1.n() << "-by-" << options.problem_size1.k() + << std::endl; + + std::cout << " Runtime / iteration: " << elapsed_ms_per_iter << " ms\n" << std::endl; + std::cout << " GFLOPs: " << gflops_per_second << " GFLOPs" << std::endl; + + return true; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, const char **argv) { + + // Define final layout + using LayoutOutput = cutlass::layout::ColumnMajor; + + // Options parsing + Options options; + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (!options.supported()) { + return 0; + } + + // Run + Testbed testbed(options); + + Disposition disposition = testbed.run(); + + std::cout << std::endl; + + switch (disposition) { + case Disposition::kPassed: + std::cout << "Passed" << std::endl; + break; + case Disposition::kIncorrect: + std::cout << "Incorrect" << std::endl; + break; + case Disposition::kNotVerified: + std::cout << "Not verified" << std::endl; + break; + } + + return (disposition == Disposition::kPassed ? 0 : -1); +} + + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/37_gemm_layernorm_gemm_fusion/gemm_with_epilogue_visitor.h b/examples/37_gemm_layernorm_gemm_fusion/gemm_with_epilogue_visitor.h new file mode 100644 index 0000000000..666f3cb566 --- /dev/null +++ b/examples/37_gemm_layernorm_gemm_fusion/gemm_with_epilogue_visitor.h @@ -0,0 +1,444 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief GEMM kernel to support the epilogue visitor model + for customized layernorm partial reduction epilogue fusion. + + This source file will likely be moved to `include/cutlass/gemm/kernel/` in the future once + its usage has been stabilized. For now, it is included in this example to demonstrate + some basic output fusion options. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" + +#include "cutlass/trace.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_ ///! 
Threadblock swizzling function +> +struct GemmWithEpilogueVisitor { +public: + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueVisitor = typename Epilogue::Visitor; + using ThreadblockSwizzle = ThreadblockSwizzle_; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using TensorRefA = TensorRef; + + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using TensorRefB = TensorRef; + + using ElementC = typename EpilogueVisitor::ElementOutput; + using LayoutC = typename Epilogue::Layout; + + static ComplexTransform const kTransformA = Mma::kTransformA; + static ComplexTransform const kTransformB = Mma::kTransformB; + using Operator = typename Mma::Operator; + + using OperatorClass = typename Mma::Operator::OperatorClass; + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename Mma::Operator::Shape; + using InstructionShape = typename Mma::Policy::Operator::InstructionShape; + using ArchTag = typename Mma::ArchTag; + + static int const kStages = Mma::kStages; + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = EpilogueVisitor::kElementsPerAccess; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + /// Split-K preserves splits that are 128b aligned + static int const kSplitKAlignment = const_max( + 128 / sizeof_bits::value, + 128 / sizeof_bits::value + ); + + // + // Structures + // + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmUniversalMode mode; + GemmCoord problem_size; + + TensorRefA ref_A; + TensorRefB ref_B; + + typename EpilogueVisitor::Arguments epilogue_visitor; + + // + // Methods + // + + Arguments(): + mode(GemmUniversalMode::kGemm) + { } + + + /// constructs an arguments structure + Arguments( + GemmUniversalMode mode_, + GemmCoord problem_size_, + TensorRefA ref_A_, + TensorRefB ref_B_, + typename EpilogueVisitor::Arguments epilogue_visitor_ + ): + mode(mode_), + problem_size(problem_size_), + ref_A(ref_A_), + ref_B(ref_B_), + epilogue_visitor(epilogue_visitor_) + { + + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + typename Mma::IteratorA::Params params_A; + typename Mma::IteratorB::Params params_B; + + GemmUniversalMode mode; + int gemm_k_size; + + void * ptr_A; + void * ptr_B; + + typename EpilogueVisitor::Params epilogue_visitor; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + swizzle_log_tile(0), + params_A(0), + params_B(0), + gemm_k_size(0), + mode(cutlass::gemm::GemmUniversalMode::kGemm), + ptr_A(nullptr), + ptr_B(nullptr) + { } + + + Params( + Arguments const &args + ): + problem_size(args.problem_size), + swizzle_log_tile(0), + params_A(args.ref_A.layout()), + params_B(args.ref_B.layout()), + mode(args.mode), + gemm_k_size(args.problem_size.k()), + ptr_A(args.ref_A.data()), + ptr_B(args.ref_B.data()), + epilogue_visitor(args.epilogue_visitor) + { + + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, 1); + + if 
(args.mode == GemmUniversalMode::kGemm || args.mode == GemmUniversalMode::kGemmSplitKParallel) { + + int const kAlignK = const_max(const_max(128 / sizeof_bits::value, 128 / sizeof_bits::value), 1); + + gemm_k_size = round_up(args.problem_size.k(), kAlignK); + + if (gemm_k_size) { + grid_tiled_shape.k() = ceil_div(args.problem_size.k(), gemm_k_size); + } + } + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + } + }; + + /// Shared memory storage structure + union SharedStorage { + + typename Mma::SharedStorage main_loop; + + struct { + typename Epilogue::SharedStorage epilogue; + typename EpilogueVisitor::SharedStorage visitor; + } epilogue; + }; + +public: + + // + // Methods + // + + CUTLASS_DEVICE + GemmWithEpilogueVisitor() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size) { + + CUTLASS_TRACE_HOST("GemmWithEpilogueVisitor::can_implement()"); + + static int const kAlignmentA = Mma::IteratorA::AccessType::kElements; + static int const kAlignmentB = Mma::IteratorB::AccessType::kElements; + static int const kAlignmentC = Epilogue::OutputTileIterator::kElementsPerAccess; + + bool isAMisaligned = false; + bool isBMisaligned = false; + bool isCMisaligned = false; + + if (platform::is_same::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } else if (platform::is_same::value) { + isAMisaligned = problem_size.m() % kAlignmentA; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isAMisaligned = problem_size.k() % kAlignmentA; + } + + if (platform::is_same::value) { + isBMisaligned = problem_size.n() % kAlignmentB; + } else if (platform::is_same::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isBMisaligned = problem_size.k() % kAlignmentB; + } + + if (platform::is_same::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } else if (platform::is_same::value) { + isCMisaligned = problem_size.m() % kAlignmentC; + } else if (platform::is_same>::value + || platform::is_same>::value) { + isCMisaligned = problem_size.n() % kAlignmentC; + } + + if (isAMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for A operand"); + return Status::kErrorMisalignedOperand; + } + + if (isBMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for B operand"); + return Status::kErrorMisalignedOperand; + } + + if (isCMisaligned) { + CUTLASS_TRACE_HOST(" returning kErrorMisalignedOperand for C operand"); + return Status::kErrorMisalignedOperand; + } + + CUTLASS_TRACE_HOST(" returning kSuccess"); + + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return can_implement(args.problem_size); + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA *ptr_A = static_cast(params.ptr_A); + ElementB *ptr_B = static_cast(params.ptr_B); + + // Compute initial location in logical coordinates + 
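+    // The threadblock's tile coordinates are mapped to element offsets here: the A tile begins
+    // at row (tile_m * Mma::Shape::kM) and column offset_k, and the B tile begins at row offset_k
+    // and column (tile_n * Mma::Shape::kN). offset_k is zero because each threadblock iterates
+    // over the full K extent in the main loop below.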
cutlass::MatrixCoord tb_offset_A{ + threadblock_tile_offset.m() * Mma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B{ + offset_k, + threadblock_tile_offset.n() * Mma::Shape::kN + }; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.params_A, + ptr_A, + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A); + + typename Mma::IteratorB iterator_B( + params.params_B, + ptr_B, + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma( + gemm_k_iterations, + accumulators, + iterator_A, + iterator_B, + accumulators); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * Mma::Shape::kM, + threadblock_tile_offset.n() * Mma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + // + // Construct the epilogue visitor + // + + EpilogueVisitor epilogue_visitor( + params.epilogue_visitor, + shared_storage.epilogue.visitor, + params.problem_size.mn(), + thread_idx, + warp_idx, + lane_idx, + threadblock_offset); + + if (params.mode == GemmUniversalMode::kGemm) { + // Indicate which position in a serial reduction the output operator is currently updating + epilogue_visitor.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + else if (params.mode == GemmUniversalMode::kBatched || params.mode == GemmUniversalMode::kArray) { + epilogue_visitor.set_batch_index(threadblock_tile_offset.k()); + } + + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Execute the epilogue operator to update the destination tensor. + epilogue(epilogue_visitor, accumulators); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h b/examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h new file mode 100644 index 0000000000..b33954ecce --- /dev/null +++ b/examples/37_gemm_layernorm_gemm_fusion/gemm_with_layernorm.h @@ -0,0 +1,1066 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief A file contains all functioning classes needed by GemmLayernorm. + + GemmLayernorm example = GEMM0 with partial reduction fused in epilogue (EpilogueVisitorLayerNorm) + + lightweight full reduction kernel (ApplyFinalReduction) + + GEMM1 with elemenwise operations fused in mainloop (GemmLayernormMainloopFusion) + +*/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/arch/memory.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/gemm/device/gemm_layernorm_mainloop_fusion.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/kernel/default_gemm_complex.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "gemm_with_epilogue_visitor.h" +#include "helper.h" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementVariance_, + typename ElementMean_, + typename ElementLayernormCompute_, + typename ElementOutput, + typename ThreadblockShape_, + bool IsShiftedVariance_ = false +> +class ApplyFinalReduction { +public: + + using ElementVariance = ElementVariance_; + using ElementMean = ElementMean_; + using ElementLayernormCompute = ElementLayernormCompute_; + using ThreadblockShape = 
ThreadblockShape_; + + // Pre-processing has ensured the layout equivelent to RowMajor + using Layout = cutlass::layout::RowMajor; + + using TensorVariance = TensorRef; + using TensorMean = TensorRef; + + static bool const kIsShiftedVariance = IsShiftedVariance_; + + // + // Arguments + // + + struct Arguments { + + MatrixCoord extent; ///< Extent of D and Layernorm matrices + TensorVariance ref_Variance; ///< Sum Square or Variance tensor (input / output) + TensorMean ref_Mean; ///< Sum or Mean tensor (input / output) + ElementOutput *ptr_Shifted_K; ///< Shifted K tensor pointer + + // + // Methods + // + Arguments(){ } + + Arguments( + MatrixCoord extent_, + TensorVariance ref_Variance_, + TensorMean ref_Mean_, + ElementOutput *ptr_Shifted_K_ + ): + extent(extent_), + ref_Variance(ref_Variance_), + ref_Mean(ref_Mean_), + ptr_Shifted_K(ptr_Shifted_K_) + { + + } + }; + + struct SharedStorage { + + + }; + + // + // Params struct + // + + struct Params { + Arguments args; + + // + // Methods + // + Params() { } + + Params(Arguments const &args_): args(args_) { } + }; + +private: + +public: + + CUTLASS_DEVICE + ApplyFinalReduction() { } + + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + apply(params, shared_storage); + } + +private: + + /// Partial reduction + CUTLASS_DEVICE + void apply(Params const ¶ms, SharedStorage &shared_storage) { + + int threadblock_num = (params.args.extent.column() + ThreadblockShape::kM - 1) / ThreadblockShape::kM; + + int block_n = blockIdx.x * blockDim.x; + + int thread_n = threadIdx.x; + + int idx_n = block_n + thread_n; + + if (idx_n >= params.args.extent.row()) { + return; + } + + using ConvertVarianceOutput = cutlass::NumericConverter; + using ConvertMeanOutput = cutlass::NumericConverter; + + using ConvertVariance = cutlass::NumericConverter; + using ConvertMean = cutlass::NumericConverter; + + using ConvertShiftK = cutlass::NumericConverter; + + ConvertVariance convert_variance; + ConvertMean convert_mean; + + ConvertVarianceOutput convert_variance_output; + ConvertMeanOutput convert_mean_output; + + ElementVariance *access_square = params.args.ref_Variance.data() + idx_n; + ElementMean *access_mean = params.args.ref_Mean.data() + idx_n; + + ElementVariance *access_square_bak = access_square; + ElementMean *access_mean_bak = access_mean; + + ElementLayernormCompute frag_square_sum = ElementLayernormCompute(0); + ElementLayernormCompute frag_element_sum = ElementLayernormCompute(0); + ElementVariance fetch_square; + ElementMean fetch_mean; + + CUTLASS_PRAGMA_UNROLL + for (int idx_m = 0; idx_m < threadblock_num; idx_m++) { + arch::global_load(fetch_square, access_square, true); + arch::global_load(fetch_mean, access_mean, true); + frag_element_sum += convert_mean(fetch_mean); + frag_square_sum += convert_variance(fetch_square); + access_square += params.args.extent.row(); + access_mean += params.args.extent.row(); + } + + ElementLayernormCompute mean = frag_element_sum; + ElementLayernormCompute square_mean = frag_square_sum; + + ElementLayernormCompute variance; + + if (kIsShiftedVariance && params.args.ptr_Shifted_K != nullptr) { + ElementOutput *access_shift_k = params.args.ptr_Shifted_K + idx_n; + ElementOutput fetch_shift_k; + ConvertShiftK convert_shift_k; + arch::global_load(fetch_shift_k, access_shift_k, true); + ElementLayernormCompute shifted_mean = mean - convert_shift_k(fetch_shift_k); + variance = cutlass::constants::one() / cutlass::fast_sqrt(square_mean - shifted_mean * shifted_mean + 
ElementLayernormCompute(1e-6)); + }else{ + variance = cutlass::constants::one() / cutlass::fast_sqrt(square_mean - mean * mean + ElementLayernormCompute(1e-6)); + } + + mean = -mean * variance; + + access_square = access_square_bak; + access_mean = access_mean_bak; + + access_square[0] = convert_variance_output(variance); + access_mean[0] = convert_mean_output(mean); + + } + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ThreadblockShape_, + int ThreadCount, + typename OutputTileIterator_, + typename AccumulatorTile_, + typename ElementAccumulator_, + typename ElementVariance_, + typename ElementMean_, + typename ElementLayernormCompute_, + typename ElementwiseFunctor_, + bool IsShiftedVariance_ = false +> +class EpilogueVisitorLayerNorm { +public: + + using ElementVariance = ElementVariance_; + using ElementMean = ElementMean_; + using ElementLayernormCompute = ElementLayernormCompute_; + + using AccumulatorTile = AccumulatorTile_; + + using ThreadblockShape = ThreadblockShape_; + static int const kThreadCount = ThreadCount; + + using OutputTileIterator = OutputTileIterator_; + using ElementwiseFunctor = ElementwiseFunctor_; + + static int const kIterations = OutputTileIterator::kIterations; + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + static int const kRowIterations = OutputTileIterator::ThreadMap::Iterations::kRow; + + static int const kThreads = OutputTileIterator::ThreadMap::kThreads; + + static bool const kIsShiftedVariance = IsShiftedVariance_; + + using ElementOutput = typename OutputTileIterator::Element; + + static int const kDeltaRow = OutputTileIterator::ThreadMap::Delta::kRow; + + /// Array type used in Shift-K Layernorm + static int const kRowAccessCount = kIterations * kRowIterations; + + using ConvertedShiftFragment = Array; + + // Conducts manual transpose externally (already supported) for column major + using LayoutOutput = cutlass::layout::RowMajor; + + using ElementAccumulator = ElementAccumulator_; + + using AccumulatorFragment = Array; + using LayernormFragment = Array; + using OutputVector = Array; + using TensorRefD = TensorRef; + + static int const kThreadsPerRow = OutputTileIterator::ThreadMap::Detail::RowArrangement::Detail::kShapeWidth; + static int const kThreadsInColumn = kThreads / kThreadsPerRow; + static int const kHalfThreadsPerRow = (kThreadsPerRow >> 1); + + /// Argument structure + struct Arguments { + + typename ElementwiseFunctor::Params elementwise; + TensorRefD ref_C; + TensorRefD ref_D; + ElementVariance *ptr_Variance; + ElementMean *ptr_Mean; + ElementOutput *ptr_Shifted_K; + + // + // Methods + // + Arguments(): + ptr_Variance(nullptr), + ptr_Mean(nullptr), + ptr_Shifted_K(nullptr) + { + + } + + Arguments( + typename ElementwiseFunctor::Params elementwise_, + TensorRefD ref_C_, + TensorRefD ref_D_, + ElementVariance *ptr_Variance, + ElementMean *ptr_Mean_, + ElementOutput *ptr_Shifted_K_ = nullptr + ): + elementwise(elementwise_), + ref_C(ref_C_), + ref_D(ref_D_), + ptr_Variance(ptr_Variance), + ptr_Mean(ptr_Mean_), + ptr_Shifted_K(ptr_Shifted_K_) + { + + } + }; + + struct Params { + + typename ElementwiseFunctor::Params elementwise; + typename OutputTileIterator::Params params_C; + typename OutputTileIterator::Params params_D; + typename OutputTileIterator::Element *ptr_C; + typename OutputTileIterator::Element *ptr_D; + ElementVariance *ptr_Variance; + ElementMean *ptr_Mean; + ElementOutput *ptr_Shifted_K; + + // + // 
Methods + // + CUTLASS_HOST_DEVICE + Params(): + ptr_D(nullptr), + ptr_Variance(nullptr), + ptr_Mean(nullptr) + { + + } + + CUTLASS_HOST_DEVICE + Params(Arguments const &args): + elementwise(args.elementwise), + params_C(args.ref_C.layout()), + params_D(args.ref_D.layout()), + ptr_C(args.ref_C.data()), + ptr_D(args.ref_D.data()), + ptr_Variance(args.ptr_Variance), + ptr_Mean(args.ptr_Mean), + ptr_Shifted_K(args.ptr_Shifted_K) + { + + } + }; + + /// Shared storage + struct SharedStorage { + + }; + +private: + + Params const & params_; + SharedStorage & shared_storage_; + MatrixCoord extent_; + ElementwiseFunctor elementwise_; + + OutputTileIterator iterator_C_; + OutputTileIterator iterator_D_; + typename OutputTileIterator::Fragment fragment_C_; + typename OutputTileIterator::Fragment fragment_D_; + + ElementAccumulator alpha_; + ElementAccumulator beta_; + ConvertedShiftFragment shift_k_frag_; + + ElementLayernormCompute accum_sum_square_; + ElementLayernormCompute accum_sum_element_; + + MatrixCoord thread_offset_; + +public: + + CUTLASS_DEVICE + EpilogueVisitorLayerNorm( + Params const ¶ms, ///< Parameters routed to the epilogue + SharedStorage &shared_storage, ///< Shared storage needed by the functors here + MatrixCoord const &problem_size0, ///< Problem size of the output + int thread_idx, ///< Thread index within the threadblock + int warp_idx, ///< Warp index within the threadblock + int lane_idx, ///< Lane index within the warp + MatrixCoord const &threadblock_offset = MatrixCoord(0, 0) + ): + params_(params), + shared_storage_(shared_storage), + extent_(problem_size0), + elementwise_(params.elementwise), + iterator_C_(params.params_C, params.ptr_C, problem_size0, thread_idx, threadblock_offset), + iterator_D_(params.params_D, params.ptr_D, problem_size0, thread_idx, threadblock_offset) + { + alpha_ = (params.elementwise.alpha_ptr ? *params.elementwise.alpha_ptr : params.elementwise.alpha); + beta_ = (params.elementwise.beta_ptr ? 
*params.elementwise.beta_ptr : params.elementwise.beta); + + if (beta_ == ElementAccumulator()) { + iterator_C_.clear_mask(); + } + } + + /// Helper to indicate split-K behavior + CUTLASS_DEVICE + void set_k_partition( + int split_k_index, ///< Index of this threadblock within split-K partitioned scheme + int split_k_slices) { ///< Total number of split-K slices + + } + + /// Called to set the batch index + CUTLASS_DEVICE + void set_batch_index(int batch_idx) { + + } + + /// Called at the start of the epilogue just before iterating over accumulator slices + CUTLASS_DEVICE + void begin_epilogue() { + + // If shift-K feature is enabled, we load shift-k fragment + // at the very beginning of an epilogue + if (kIsShiftedVariance && params_.ptr_Shifted_K != nullptr) { + shift_k_frag_.clear(); + int thread_offset_row_base = iterator_D_.thread_start_row(); + + CUTLASS_PRAGMA_UNROLL + for (int iter_idx = 0; iter_idx < kIterations; ++iter_idx) { + int step_offset = iter_idx * OutputTileIterator::Shape::kRow; + CUTLASS_PRAGMA_UNROLL + for (int rid = 0; rid < kRowIterations; ++rid) { + int row_step_offset = rid * kDeltaRow; + int row_offset = thread_offset_row_base + step_offset + row_step_offset; + bool is_load = (row_offset < extent_.row()); + shift_k_frag_[iter_idx * kRowIterations + rid] = load_shift_k_(row_offset, is_load); + } + + } + + } + + } + + /// Called at the start of one step before starting accumulator exchange + CUTLASS_DEVICE + void begin_step(int step_idx) { + fragment_D_.clear(); + + if (elementwise_.kScale != cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { + fragment_C_.clear(); + iterator_C_.load(fragment_C_); + ++iterator_C_; + } + } + + /// Called at the start of a row + CUTLASS_DEVICE + void begin_row(int row_idx) { + + } + + /// Called after accumulators have been exchanged for each accumulator vector + CUTLASS_DEVICE + void visit( + int iter_idx, + int row_idx, + int column_idx, + int frag_idx, + AccumulatorFragment const &accum) { + + using Mul = cutlass::multiplies; + using Minus = cutlass::minus; + using Exp = cutlass::fast_exp_op; + + [[maybe_unused]] Minus minus; + [[maybe_unused]] Mul mul; + [[maybe_unused]] Exp exponential; + + LayernormFragment result; + + thread_offset_ = + iterator_D_.thread_start() + + OutputTileIterator::ThreadMap::iteration_offset(frag_idx); + + NumericArrayConverter source_converter; + OutputVector &source_vector = reinterpret_cast(&fragment_C_)[frag_idx]; + + bool column_guard = (thread_offset_.column() < extent_.column()); + + if (elementwise_.kScale == cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling) { + result = source_converter(elementwise_(accum)); + }else{ + result = source_converter(elementwise_(accum, source_vector)); + } + + + ElementLayernormCompute inv_scalar = cutlass::constants::one() / ElementLayernormCompute(extent_.column()); + + // Fragment is cleared for non-reachable columns so no need to check against column guard + accum_sum_element_ = element_sum_accumulator_(result); + + // Square sum is different. Non-reachable columns should've been computed for shift-k + // Otherwise we will incorrectly have some extra k^2 added into square sum. + if (column_guard) { + accum_sum_square_ = (kIsShiftedVariance) ? 
\ + square_sum_accumulator_(result, shift_k_frag_[iter_idx * kRowIterations + row_idx]) : \ + square_sum_accumulator_(result); + } + else { + accum_sum_square_ = ElementLayernormCompute(0); + } + + accum_sum_element_ *= inv_scalar; + accum_sum_square_ *= inv_scalar; + + // After performing the in-thread reduction, we then perform cross-thread / in-warp reduction + CUTLASS_PRAGMA_UNROLL + for (int i = kHalfThreadsPerRow; i > 0; i >>= 1) { + accum_sum_element_ += __shfl_xor_sync(0xFFFFFFFF, accum_sum_element_, i); + accum_sum_square_ += __shfl_xor_sync(0xFFFFFFFF, accum_sum_square_, i); + } + + // Convert to the output + NumericArrayConverter output_converter; + OutputVector &output = reinterpret_cast(&fragment_D_)[frag_idx]; + output = output_converter(result); + } + + /// Called at the start of a row + CUTLASS_DEVICE + void end_row(int row_idx) { + + using ConvertVarianceOutput = cutlass::NumericConverter; + using ConvertMeanOutput = cutlass::NumericConverter; + + ConvertVarianceOutput convert_variance_output; + ConvertMeanOutput convert_mean_output; + + bool is_write_thread = (thread_offset_.row() < extent_.row() && (threadIdx.x % kThreadsPerRow) == 0); + int row_offset = thread_offset_.row() + blockIdx.y * extent_.row(); + + ElementVariance *curr_ptr_sum_square = params_.ptr_Variance + row_offset; + ElementMean *curr_ptr_element_sum = params_.ptr_Mean + row_offset; + + arch::global_store( + convert_variance_output(accum_sum_square_), + (void *)curr_ptr_sum_square, + is_write_thread); + + arch::global_store( + convert_mean_output(accum_sum_element_), + (void *)curr_ptr_element_sum, + is_write_thread); + + } + + /// Called after all accumulator elements have been visited + CUTLASS_DEVICE + void end_step(int step_idx) { + + iterator_D_.store(fragment_D_); + ++iterator_D_; + } + + /// Called after all steps have been completed + CUTLASS_DEVICE + void end_epilogue() { + + } + +private: + + CUTLASS_DEVICE + ElementLayernormCompute load_shift_k_(int row_offset, bool is_load) { + using ConvertShiftK = cutlass::NumericConverter; + ConvertShiftK convert_shift_k; + ElementOutput shift_k_val; + + // Computes the address to load shift_k element + ElementOutput *curr_ptr_shift_k = params_.ptr_Shifted_K + row_offset; + // Conditionally loads from global memory + arch::global_load(shift_k_val, (void *)curr_ptr_shift_k, is_load); + // Converts data type to return + ElementLayernormCompute converted_shift_k_val = convert_shift_k(shift_k_val); + + return converted_shift_k_val; + } + + CUTLASS_DEVICE + ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum) { + ElementLayernormCompute sum_ = ElementLayernormCompute(0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < LayernormFragment::kElements; ++i) { + auto accum_ = accum[i]; + sum_ += accum_ * accum_; + } + + return sum_; + } + + CUTLASS_DEVICE + ElementLayernormCompute square_sum_accumulator_(LayernormFragment const &accum, ElementLayernormCompute shift_k_val) { + ElementLayernormCompute sum_ = ElementLayernormCompute(0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < LayernormFragment::kElements; ++i) { + auto accum_ = accum[i] - shift_k_val; + sum_ += accum_ * accum_; + } + + return sum_; + } + + CUTLASS_DEVICE + ElementLayernormCompute element_sum_accumulator_(LayernormFragment const &accum) { + ElementLayernormCompute sum_ = ElementLayernormCompute(0); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < LayernormFragment::kElements; ++i) { + sum_ += accum[i]; + } + + return sum_; + } + +}; + 
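+// How the two kernels above fit together (an explanatory sketch, not code used by the example):
+//
+// For every output row, EpilogueVisitorLayerNorm accumulates the partial statistics
+//
+//   partial_mean = (1/N) * sum_j x[j]        over the columns this threadblock produces
+//   partial_sq   = (1/N) * sum_j x[j]^2      (optionally shifted by a per-row constant when
+//                                             the shift-K feature is enabled)
+//
+// reduces them across the threads covering the row with __shfl_xor_sync, and stores one partial
+// pair per threadblock. ApplyFinalReduction then sums the partials of each row and overwrites
+// them with the values consumed downstream:
+//
+//   inv_std = 1 / sqrt(sum_sq - mean * mean + 1e-6)
+//   bias    = -mean * inv_std
+//
+// so the normalization can later be applied as a single multiply-add, since
+// x * inv_std + bias == (x - mean) * inv_std.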
+///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// +template < + typename ElementInputA0_, + typename LayoutInputA0_, + typename ElementInputB0_, + typename LayoutInputB0_, + typename ElementOutput_, + typename LayoutOutput_, + typename ElementCompute_, + typename EpilogueFunctorOp_, + typename ThreadblockShape_, + typename WarpShape_, + typename InstructionShape_, + int Stages0, + int Stages1, + bool IsShiftedVariance_ = false +> +class GemmLayernorm { +public: + + /////////////////////////////////////////////////////////////////////////////////////////////// + + // + // Type definitions + // + + static bool const kInternalTranspose = cutlass::platform::is_same::value; + static bool const kIsShiftedVariance = IsShiftedVariance_; + + // These is mandatory layout. + using LayoutInputScaleBias = cutlass::layout::RowMajor; + + // These are mandatory data types. + using ElementLayernormCompute = float; + using ElementInputScaleBias = cutlass::half_t; + + // These are mandatory params required by mainloop fusion + using OperatorClass = cutlass::arch::OpClassTensorOp; + using ArchTag = cutlass::arch::Sm80; + + // These are mandatory layouts and data types + // that are inheritated from pre-defined params + + using LayoutSumSqr = LayoutInputScaleBias; + using LayoutSum = LayoutInputScaleBias; + + using ElementMean = ElementInputScaleBias; + using ElementVariance = ElementInputScaleBias; + + /////////////////////////////////////////////////////////////////////////////////////////////// + + using LayoutInputA0 = LayoutInputA0_; + using LayoutInputB0 = LayoutInputB0_; + using LayoutInputA1 = LayoutOutput_; + using LayoutInputB1 = LayoutOutput_; + using LayoutOutputC0 = LayoutOutput_; + using LayoutOutputC1 = LayoutOutput_; + + using ElementInputA0 = ElementInputA0_; + using ElementInputB0 = ElementInputB0_; + using ElementOutputC0 = ElementOutput_; + using ElementCompute = ElementCompute_; + using ElementInputB1 = ElementInputB0_; + + using ElementInputA1 = ElementOutputC0; + using ElementOutputC1 = ElementOutputC0; + + using EpilogueFunctorOp = EpilogueFunctorOp_; + + using TensorRefA = TensorRef; + using TensorRefB = TensorRef; + using TensorRefC = TensorRef; + using TensorVariance = TensorRef; + using TensorMean = TensorRef; + + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + + static int const kStages0 = Stages0; + static int const kStages1 = Stages1; + + using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + + /////////////////////////////////////////////////////////////////////////////////////////////// + + using MapArguments = cutlass::gemm::kernel::detail::MapArguments< + ElementInputA0, + LayoutInputA0, + cutlass::ComplexTransform::kNone, + 128 / cutlass::sizeof_bits::value, + ElementInputB0, + LayoutInputB0, + cutlass::ComplexTransform::kNone, + 128 / cutlass::sizeof_bits::value, + LayoutOutputC0, + kInternalTranspose + >; + + using DefaultGemmKernel = typename cutlass::gemm::kernel::DefaultGemm< + typename MapArguments::ElementA, + typename MapArguments::LayoutA, + MapArguments::kAlignmentA, + typename MapArguments::ElementB, + typename MapArguments::LayoutB, + MapArguments::kAlignmentB, + ElementOutputC0, + typename MapArguments::LayoutC, + ElementCompute, + OperatorClass, + ArchTag, + 
ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueFunctorOp, + SwizzleThreadBlock, + kStages0, + true, + typename cutlass::gemm::device::DefaultGemmConfiguration< + OperatorClass, ArchTag, ElementInputA0, ElementInputB0, ElementOutputC0, ElementCompute>::Operator, + cutlass::gemm::SharedMemoryClearOption::kNone + >::GemmKernel; + + /////////////////////////////////////////////////////////////////////////////////////////////// + + // Epilogue visitor + using EpilogueVisitor = kernel::EpilogueVisitorLayerNorm< + ThreadblockShape, + DefaultGemmKernel::kThreadCount, + typename DefaultGemmKernel::Epilogue::OutputTileIterator, + typename DefaultGemmKernel::Epilogue::AccumulatorFragmentIterator::AccumulatorTile, + ElementCompute, + ElementVariance, + ElementMean, + ElementLayernormCompute, + EpilogueFunctorOp, + kIsShiftedVariance + >; + + /// Epilogue + using Epilogue = typename cutlass::epilogue::threadblock::EpilogueWithVisitorFromExistingEpilogue< + EpilogueVisitor, + typename DefaultGemmKernel::Epilogue + >::Epilogue; + + // GEMM + using GemmEpilogueFusion = gemm::kernel::GemmWithEpilogueVisitor< + typename DefaultGemmKernel::Mma, + Epilogue, + SwizzleThreadBlock + >; + + using ApplyFinalReductionKernel = kernel::ApplyFinalReduction< + ElementVariance, + ElementMean, + ElementLayernormCompute, + ElementOutputC0, + ThreadblockShape, + kIsShiftedVariance + >; + +using GemmMainloopFusion = typename cutlass::gemm::device::GemmLayernormMainloopFusion< + ElementInputA1, LayoutInputA1, + ElementInputB1, LayoutInputB1, + ElementInputScaleBias, LayoutInputScaleBias, + ElementOutputC1, LayoutOutputC1, + ElementCompute, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueFunctorOp, + SwizzleThreadBlock, + kStages1 +>; + +public: + + /// Arguments class + struct Arguments { + + typename GemmEpilogueFusion::Arguments gemm0; + typename GemmMainloopFusion::Arguments gemm1; + typename ApplyFinalReductionKernel::Arguments reduction; + cutlass::gemm::GemmCoord extend; + + // + // Methods + // + Arguments() { } + + Arguments( + cutlass::gemm::GemmCoord problem_size0, + cutlass::gemm::GemmCoord problem_size1, + ElementInputA0 * ptr_A, + ElementInputB0 * ptr_B, + ElementOutputC0 * ptr_C, + ElementOutputC0 * ptr_D, + ElementOutputC0 * ptr_E, + ElementOutputC0 * ptr_O, + int64_t ldm_A, + int64_t ldm_B, + int64_t ldm_C, + int64_t ldm_D, + int64_t ldm_E, + int64_t ldm_O, + typename EpilogueFunctorOp::Params linear_scaling, + TensorVariance ref_Variance_, + TensorMean ref_Mean_, + TensorVariance ref_Gamma_, + TensorMean ref_Beta_, + ElementOutputC0 *ptr_Shifted_K = nullptr + ): + gemm0( + cutlass::gemm::GemmUniversalMode::kGemm, + {kInternalTranspose ? problem_size0.n() : problem_size0.m(),\ + kInternalTranspose ? problem_size0.m() : problem_size0.n(),\ + problem_size0.k()}, + {kInternalTranspose ? ptr_B : ptr_A, \ + kInternalTranspose ? ldm_B : ldm_A}, + {kInternalTranspose ? ptr_A : ptr_B, \ + kInternalTranspose ? ldm_A : ldm_B}, + typename EpilogueVisitor::Arguments( + linear_scaling, + {ptr_C, ldm_C}, + {ptr_D, ldm_D}, + ref_Variance_.data(), + ref_Mean_.data(), + ptr_Shifted_K + ) + ), + reduction( + MatrixCoord(kInternalTranspose ? problem_size0.n() : problem_size0.m(),\ + kInternalTranspose ? problem_size0.m() : problem_size0.n()), + ref_Variance_, + ref_Mean_, + ptr_Shifted_K + ), + gemm1( + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size1, + 1, + linear_scaling, + kInternalTranspose ? ptr_E : ptr_D, + kInternalTranspose ? 
ptr_D : ptr_E, + ref_Variance_.data(), + ref_Mean_.data(), + ref_Gamma_.data(), + ref_Beta_.data(), + ptr_O, + ptr_O, + problem_size1.m() * problem_size1.k(), + problem_size1.n() * problem_size1.k(), + problem_size1.n(), + problem_size1.n(), + problem_size1.k(), + problem_size1.k(), + problem_size1.m() * problem_size1.n(), + problem_size1.m() * problem_size1.n(), + kInternalTranspose ? ldm_E : ldm_D, + kInternalTranspose ? ldm_D : ldm_D, + ref_Variance_.layout().stride(0), + ref_Mean_.layout().stride(0), + ref_Gamma_.layout().stride(0), + ref_Beta_.layout().stride(0), + ldm_O, + ldm_O + ), + extend(problem_size0) + { + + } + }; + + struct Params { + + typename GemmEpilogueFusion::Params gemm0; + typename ApplyFinalReductionKernel::Params reduction; + MatrixCoord extend; + // + // Methods + // + Params() { } + + Params(Arguments const &args): + gemm0(args.gemm0), + reduction(args.reduction), + extend(MatrixCoord(args.extend.m(), args.extend.n())) + { + + } + }; + +public: + + // Gemm + + + // + // Methods + // + +private: + + Params params_; + GemmMainloopFusion gemm_fusion_op; + +public: + + /// Ctor + GemmLayernorm() { + + } + + /// Initialize + Status initialize(Arguments const &args) { + + params_ = Params(args); + cutlass::Status status; + size_t workspace_size = gemm_fusion_op.get_workspace_size(args.gemm1); + cutlass::device_memory::allocation workspace(workspace_size); + status = gemm_fusion_op.can_implement(args.gemm1); + CUTLASS_CHECK(status); + + status = gemm_fusion_op.initialize(args.gemm1, workspace.get()); + CUTLASS_CHECK(status); + + return cutlass::Status::kSuccess; + } + + /// Run + Status run(cudaStream_t stream) { + + // + // Launch the GEMM + layernorm kernel + // + + dim3 gemm_grid = SwizzleThreadBlock().get_grid_shape(params_.gemm0.grid_tiled_shape); + dim3 gemm_block(GemmEpilogueFusion::kThreadCount, 1, 1); + + int gemm_smem_size = int(sizeof(typename GemmEpilogueFusion::SharedStorage)); + + cutlass::Kernel<<>>(params_.gemm0); + + cudaError_t result = cudaGetLastError(); + + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + // + // Launch the ApplyFinalReductionKernel + // + + // always performs reduction from leading dimension + int leading_dim_0 = kInternalTranspose ? params_.extend.row() : params_.extend.column(); + int leading_dim_1 = kInternalTranspose ? 
params_.extend.column() : params_.extend.row(); + + int thread_per_block = 128; + int block_per_row = (leading_dim_1 + thread_per_block - 1) / thread_per_block; + if (block_per_row < 4) { + thread_per_block = 32; + block_per_row = (leading_dim_1 + thread_per_block - 1) / thread_per_block; + } + + dim3 final_reduction_block(thread_per_block); + dim3 final_reduction_grid(block_per_row); + + Kernel<<< + final_reduction_grid, final_reduction_block, sizeof(typename ApplyFinalReductionKernel::SharedStorage), stream + >>>(params_.reduction); + + result = cudaGetLastError(); + + if (result != cudaSuccess) { + return cutlass::Status::kErrorInternal; + } + + // + // Launch the GEMM + mainloop fusion kernel + // + + cutlass::Status status = gemm_fusion_op(); + CUTLASS_CHECK(status); + + return cutlass::Status::kSuccess; + } + + /// Function call operator + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/38_syr2k_grouped/CMakeLists.txt b/examples/38_syr2k_grouped/CMakeLists.txt new file mode 100644 index 0000000000..461619ed3b --- /dev/null +++ b/examples/38_syr2k_grouped/CMakeLists.txt @@ -0,0 +1,36 @@ + +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + +cutlass_example_add_executable( + 38_syr2k_grouped + syr2k_grouped.cu + ) + diff --git a/examples/38_syr2k_grouped/syr2k_grouped.cu b/examples/38_syr2k_grouped/syr2k_grouped.cu new file mode 100644 index 0000000000..c1fb82e839 --- /dev/null +++ b/examples/38_syr2k_grouped/syr2k_grouped.cu @@ -0,0 +1,1466 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief SYR2K Grouped Example. + + This workload computes a batch of SYR2K operations with distinct problem sizes. This example closely + follows 24_gemm_grouped. 
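+
+    Each group computes an independent rank-2k (SYR2K) update
+
+        C_i = alpha * (A_i * B_i^T + B_i * A_i^T) + beta * C_i
+
+    where C_i is an N_i-by-N_i symmetric matrix and A_i, B_i are N_i-by-K_i matrices, which is
+    why only the N and K dimensions of each problem are specified.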
+ + Examples: + + # Runs a grouped SYR2K with 100 random problem sizes + $ ./examples/38_syr2k_grouped/38_syr2k_grouped --groups=100 + + # Runs a grouped SYR2K with 100 random problem sizes (with SYR2K-K dimension equal to 1024) + $ ./examples/38_syr2k_grouped/38_syr2k_grouped --groups=100 --k=1024 --verbose=true + + # Runs a grouped SYR2K that is equivalent to a batched SYR2K + $ ./examples/38_syr2k_grouped/38_syr2k_grouped --groups=100 --n=1024 --k=1024 --verbose=true + + # Execute grouped SYR2K and profile with NSight + $ nv-nsight-cu-cli ./examples/38_syr2k_grouped/38_syr2k_grouped --n=256 --k=256 --verbose=true \ + --iterations=1 --reference-check=false + +*/ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include +#include + +#include "cutlass/blas3.h" +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/rank_2k_grouped.h" +#include "cutlass/gemm/kernel/default_rank_2k_grouped.h" +#include "cutlass/gemm/device/rank_2k_grouped.h" +#include "cutlass/gemm/device/rank_2k.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/rank_2k_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Result structure +struct Result { + + double runtime_ms; + double initialization_time_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + // + // Methods + // + + Result( + double runtime_ms = 0, + double initialization_time_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess + ): + runtime_ms(runtime_ms), initialization_time_ms(initialization_time_ms), gflops(gflops), + status(status), error(error), passed(true) { } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + bool error; + bool reference_check; + bool profile_initialization; + bool sort_problems; + + std::vector<cutlass::gemm::GemmCoord> problem_sizes; + + int alignment; + int problem_count; + int iterations; + int cuda_streams; + bool verbose; + float alpha; + float beta; + std::string benchmark_path; + + std::string output_tag; + std::ofstream output_file; + + using GroupScheduleMode = cutlass::gemm::kernel::GroupScheduleMode; + std::vector<GroupScheduleMode> scheduler_modes; + + std::unordered_map<std::string, GroupScheduleMode> + str_to_scheduler_mode = { + {"kDeviceOnly", GroupScheduleMode::kDeviceOnly}, + {"kHostPrecompute", GroupScheduleMode::kHostPrecompute} + }; + + struct GroupScheduleModeHash { + size_t operator()(GroupScheduleMode m) const { + return static_cast<size_t>(m); + } + }; + + std::unordered_map<GroupScheduleMode, std::string, GroupScheduleModeHash> + scheduler_mode_to_str = { + {GroupScheduleMode::kDeviceOnly, "kDeviceOnly"}, + {GroupScheduleMode::kHostPrecompute, "kHostPrecompute"} + }; + + std::vector<GroupScheduleMode> all_scheduler_modes = {GroupScheduleMode::kDeviceOnly, GroupScheduleMode::kHostPrecompute};
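+  // kDeviceOnly derives the mapping of threadblock tiles to problems on the device at run time,
+  // whereas kHostPrecompute builds that schedule on the host and stages it through the kernel's
+  // workspace, trading host-side setup time for less scheduling work on the device.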
+ + // + // Methods + // + + Options(): + help(false), + error(false), + alignment(8), + reference_check(true), + profile_initialization(false), + sort_problems(false), + problem_count(5), + iterations(20), + cuda_streams(0), + verbose(false), + alpha(1), + beta(), + scheduler_modes({GroupScheduleMode::kDeviceOnly}) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("alignment", alignment, 8); + cmd.get_cmd_line_argument("groups", problem_count, 5); + cmd.get_cmd_line_argument("alpha", alpha, 1.0f); + cmd.get_cmd_line_argument("beta", beta, 0.0f); + cmd.get_cmd_line_argument("iterations", iterations, 20); + cmd.get_cmd_line_argument("streams", cuda_streams, 0); + cmd.get_cmd_line_argument("verbose", verbose, false); + cmd.get_cmd_line_argument("reference-check", reference_check, true); + cmd.get_cmd_line_argument("profile-initialization", profile_initialization, false); + cmd.get_cmd_line_argument("sort-problems", sort_problems, false); + cmd.get_cmd_line_argument("benchmark", benchmark_path); + + std::vector scheduler_mode_strs; + cmd.get_cmd_line_arguments("scheduler-modes", scheduler_mode_strs); + + if (!scheduler_mode_strs.empty()) { + scheduler_modes.clear(); + if (scheduler_mode_strs.size() == 1 && scheduler_mode_strs[0] == "all") { + scheduler_modes = all_scheduler_modes; + } else { + for (std::string precomp_str : scheduler_mode_strs) { + auto it = str_to_scheduler_mode.find(precomp_str); + if (it != str_to_scheduler_mode.end()) { + scheduler_modes.push_back(it->second); + } else if (precomp_str == "all") { + std::cerr << "Flag --scheduler-modes=all must not contain other scheduler modes in list." << std::endl; + error = true; + return; + } else { + std::cerr << "Unrecognized scheduler mode '" << precomp_str << "'" << std::endl; + error = true; + return; + } + } + } + } + + std::string output_path; + cmd.get_cmd_line_argument("tag", output_tag); + cmd.get_cmd_line_argument("output_file", output_path); + + if (!output_path.empty()) { + + std::ios_base::openmode open_mode = std::ios_base::out; + + std::ifstream input_file(output_path.c_str()); + + if (input_file.good()) { + open_mode = std::ios_base::app; + input_file.close(); + } + + output_file.open(output_path.c_str(), open_mode); + + if (output_file.good() && open_mode != std::ios_base::app) { + output_file << "Tag,Provider,Kind,Groups,Runtime,GFLOPs\n"; + } + } + + // Decide how to initialize the problems + if (!benchmark_path.empty()) { + if (!benchmark_problems()) { + error = true; + problem_sizes.clear(); + return; + } + } + else { + randomize_problems(cmd); + } + } + + void randomize_problems(cutlass::CommandLine &cmd) { + + // + // For now, randomly choose the problem sizes. + // + + int cmd_line_m = -1; + int cmd_line_n = -1; + int cmd_line_k = -1; + + cmd.get_cmd_line_argument("m", cmd_line_m); + cmd.get_cmd_line_argument("n", cmd_line_n); + cmd.get_cmd_line_argument("k", cmd_line_k); + + // SYR2K is defined via only N and K. + if (cmd_line_m != -1) { + std::cerr << "Parameter M is ignored for SYR2K\n"; + error = true; + return; + } + + problem_sizes.reserve(problem_count); + + for (int i = 0; i < problem_count; ++i) { + int n = cmd_line_n; + int k = cmd_line_k; + + if (n < 1) { + n = alignment * ((rand() % 256) + 1); + } + + if (k < 1) { + k = alignment * ((rand() % 256) + 1); + } + + // SYR2K is defined only in terms of N and K. 
Replicate N into + // the SYR2K-N dimension. + cutlass::gemm::GemmCoord problem(n, n, k); + + problem_sizes.push_back(problem); + } + } + + /// Load a benchmark + bool benchmark_problems() { + std::ifstream file(benchmark_path); + if (!file.good()) { + return false; + } + + while (file.good()) { + + int idx = -1; + std::string extent_str; + + file >> idx >> extent_str; + + if (idx < 0 || extent_str.empty()) { + break; + } + + cutlass::gemm::GemmCoord extent; + std::vector<std::string> tokens; + + cutlass::CommandLine::tokenize(tokens, extent_str, 'x'); + + for (int i = 0; i < int(tokens.size()); ++i) { + int x = std::atoi(tokens.at(i).c_str()); + + // round up + if (x % alignment) { + x += (alignment - (x % alignment)); + } + + extent.at(i) = x; + } + + if (extent.product()) { + problem_sizes.push_back(extent); + } + } + + problem_count = int(problem_sizes.size()); + return true; + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "38_syr2k_grouped\n\n" + << " This example profiles the performance of a 'grouped' SYR2K kernel. This example closely follows 24_gemm_grouped.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --benchmark=<str> Executes a benchmark problem size.\n" + << " --output_file=<str> Path to a CSV file to output results. If it exists already, results are appended.\n" + << " --tag=<str> String tag to prepend to the CSV file.\n" + << " --groups=<int> Number of individual SYR2K problems (default: --groups=15)\n" + << " --m=<int> Sets the M dimension for all groups. Otherwise, it is selected randomly\n" + << " --n=<int> Sets the N dimension for all groups. Otherwise, it is selected randomly\n" + << " --k=<int> Sets the K dimension for all groups. Otherwise, it is selected randomly\n" + << " --alpha=<f32> Epilogue scalar alpha (real part)\n" + << " --beta=<f32> Epilogue scalar beta (real part)\n" + << " --scheduler-modes=<str> List of scheduler modes to be profiled for the grouped GEMM scheduler (default: --scheduler-modes=kDeviceOnly)\n" + << " --iterations=<int> Number of profiling iterations to perform.\n" + << " --reference-check=<bool> If true, performs reference check.\n" + << " --verbose=<bool> If true, prints problem sizes and batching structure.\n" + << " --profile-initialization=<bool> If true, profiles the device-level kernel's initialization.\n" + << " --sort-problems=<bool> If true, sorts problem sizes in descending order of SYR2K-K dimension.\n"; + + out << "\n\nExamples:\n\n" + + << "# Runs a grouped SYR2K with 100 random problem sizes\n" + << "$ ./examples/38_syr2k_grouped/38_syr2k_grouped --groups=100\n\n" + + << "# Runs a grouped SYR2K with 100 random problem sizes (with K dimension equal to 1024)\n" + << "$ ./examples/38_syr2k_grouped/38_syr2k_grouped --groups=100 --k=1024 --verbose=true\n\n" + + << "# Runs a grouped SYR2K that is equivalent to a batched SYR2K\n" + << "$ ./examples/38_syr2k_grouped/38_syr2k_grouped --groups=100 --n=1024 --k=1024 --verbose=true\n\n" + + << "# Runs a grouped SYR2K with each different scheduler mode\n" + << "$ ./examples/38_syr2k_grouped/38_syr2k_grouped --scheduler-modes=all\n\n" + + << "# Runs a grouped SYR2K with each different scheduler mode and profiles host-side initialization time\n" + << "$ ./examples/38_syr2k_grouped/38_syr2k_grouped --scheduler-modes=all --profile-initialization=true\n\n" + + << "# Runs a grouped SYR2K problem given an externally supplied benchmark file. This is a text file in which\n" + << "# Each line contains a unique group index and an MxNxK triple indicating problem size.
NOTE that the\n" + << "# GEMM-M and GEMM-N dimensions must match.\n" + << "#\n" + << "# For example, assume the following are the contents of 'problems.txt'\n" + << "#\n" + << "# 0 256x256x520\n" + << "# 1 264x264x1024\n" + << "# 2 48x48x1024\n" + << "#\n" + << "$ ./examples/38_syr2k_grouped/38_syr2k_grouped --benchmark=problems.txt\n\n" + + << "# Execute Grouped SYR2K and profile with NSight\n" + << "$ nv-nsight-cu-cli ./examples/38_syr2k_grouped/38_syr2k_grouped --n=256 --k=256 --verbose=true --iterations=1 --reference-check=false\n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + + // Number of real-valued multiply-adds + int64_t fmas = int64_t(); + + for (auto const & problem : problem_sizes) { + fmas += problem.product(); + } + + // SYR2K is defined as (A x BT) + (B x AT), so the number of FMAs is twice that in a GEMM + fmas *= 2; + + // Two flops per multiply-add + return 2.0 * double(fmas) / double(1.0e9) / runtime_s; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class BaseTestbed { +public: + // + // Type definitions + // + + using ElementA = typename Rank2K::ElementA; + using ElementB = typename Rank2K::ElementB; + using ElementC = typename Rank2K::ElementC; + using ElementAccumulator = typename Rank2K::ElementAccumulator; + + using EpilogueOutputOp = typename Rank2K::Rank2Kkernel::Epilogue::OutputOp; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using LayoutA = typename Rank2K::LayoutA; + using LayoutB = typename Rank2K::LayoutB; + using LayoutC = typename Rank2K::LayoutC; + + using MatrixCoord = typename LayoutC::TensorCoord; + + // + // Data members + // + + Options & options; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint32_t seed; + + cutlass::DeviceAllocation problem_sizes_device; + + std::vector offset_A; + std::vector offset_B; + std::vector offset_C; + std::vector offset_D; + + std::vector lda_host; + std::vector ldb_host; + std::vector ldc_host; + std::vector ldd_host; + + cutlass::DeviceAllocation lda; + cutlass::DeviceAllocation ldb; + cutlass::DeviceAllocation ldc; + cutlass::DeviceAllocation ldd; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + + cutlass::DeviceAllocation ptr_A; + cutlass::DeviceAllocation ptr_B; + cutlass::DeviceAllocation ptr_C; + cutlass::DeviceAllocation ptr_D; + + BaseTestbed( + Options &options_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3080 + ): + options(options_), init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + int problem_count() const { + return options.problem_count; + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + Element *ptr, + size_t capacity, + cutlass::Distribution::Kind dist_kind, + uint32_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if 
(bits_output == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::device::BlockFillRandomUniform( + ptr, capacity, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::device::BlockFillRandomGaussian( + ptr, capacity, seed, Element(), Element(0.5f)); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + // Fill with increasing elements + cutlass::reference::device::BlockFillSequential( + ptr, capacity, Element(1), Element()); + } + else { + + // Fill with all 1s + cutlass::reference::device::BlockFillSequential( + ptr, capacity, Element(), Element(1)); + } + } + + /// Allocates device-side data + void allocate() { + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + + lda_host.resize(problem_count()); + ldb_host.resize(problem_count()); + ldc_host.resize(problem_count()); + ldd_host.resize(problem_count()); + + for (int32_t i = 0; i < problem_count(); ++i) { + + auto problem = options.problem_sizes.at(i); + + lda_host.at(i) = LayoutA::packed({problem.n(), problem.k()}).stride(0); + ldb_host.at(i) = LayoutB::packed({problem.n(), problem.k()}).stride(0); + ldc_host.at(i) = LayoutC::packed({problem.n(), problem.n()}).stride(0); + ldd_host.at(i) = LayoutC::packed({problem.n(), problem.n()}).stride(0); + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + + int64_t elements_A = problem.n() * problem.k(); + int64_t elements_B = problem.n() * problem.k(); + int64_t elements_C = problem.n() * problem.n(); + int64_t elements_D = problem.n() * problem.n(); + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + } + + lda.reset(problem_count()); + ldb.reset(problem_count()); + ldc.reset(problem_count()); + ldd.reset(problem_count()); + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + } + + /// Initializes device-side data + void initialize() { + problem_sizes_device.reset(problem_count()); + problem_sizes_device.copy_from_host(options.problem_sizes.data()); + + lda.copy_from_host(lda_host.data()); + ldb.copy_from_host(ldb_host.data()); + ldc.copy_from_host(ldc_host.data()); + ldd.copy_from_host(ldd_host.data()); + + // + // Assign pointers + // + + std::vector ptr_A_host(problem_count()); + std::vector ptr_B_host(problem_count()); + std::vector ptr_C_host(problem_count()); + std::vector ptr_D_host(problem_count()); + + for (int32_t i = 0; i < problem_count(); ++i) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + } + + ptr_A.reset(problem_count()); + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(problem_count()); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_C.reset(problem_count()); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(problem_count()); + ptr_D.copy_from_host(ptr_D_host.data()); + + // + // Initialize the problems of the workspace + // + + initialize_tensor(block_A.get(), block_A.size(), init_A, seed * 
2021); + initialize_tensor(block_B.get(), block_B.size(), init_B, seed * 2022); + initialize_tensor(block_C.get(), block_C.size(), init_C, seed * 2023); + + cutlass::reference::device::BlockFillSequential( + block_D.get(), block_D.size(), ElementC(), ElementC()); + } + + /// Verifies the result is a SYR2K + bool verify() { + + bool passed = true; + + for (int32_t i = 0; i < problem_count(); ++i) { + cutlass::gemm::GemmCoord problem = options.problem_sizes.at(i); + + LayoutA layout_A(lda_host.at(i)); + LayoutB layout_B(ldb_host.at(i)); + LayoutC layout_C(ldc_host.at(i)); + LayoutC layout_D(ldd_host.at(i)); + + cutlass::HostTensor host_A( + typename LayoutA::TensorCoord(problem.n(), problem.k()), /*device_backed=*/false); + cutlass::HostTensor host_B( + typename LayoutB::TensorCoord(problem.n(), problem.k()), /*device_backed=*/false); + cutlass::HostTensor host_C( + typename LayoutC::TensorCoord(problem.n(), problem.n()), /*device_backed=*/false); + cutlass::HostTensor host_D( + typename LayoutC::TensorCoord(problem.n(), problem.n()), /*device_backed=*/false); + + cutlass::device_memory::copy_to_host(host_A.host_data(), block_A.get() + offset_A.at(i), problem.n() * problem.k()); + cutlass::device_memory::copy_to_host(host_B.host_data(), block_B.get() + offset_B.at(i), problem.n() * problem.k()); + cutlass::device_memory::copy_to_host(host_C.host_data(), block_C.get() + offset_C.at(i), problem.n() * problem.n()); + cutlass::reference::host::BlockFillSequential( + host_D.host_data(), problem.n() * problem.n(), ElementC(), ElementC()); + + MatrixCoord extent_C{problem.n(), problem.n()}; + + // Reference Rank2K + cutlass::reference::host::Rank2KComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementC, ElementAccumulator + >( + problem, + (double)options.alpha, + host_A.host_view(), + Rank2K::kTransformA, + host_B.host_view(), + Rank2K::kTransformB, + (double)options.beta, + host_C.host_view(), + host_D.host_view(), + ElementAccumulator(0), + Rank2K::kFillModeC, + Rank2K::kBlasMode + ); + + // Copy to host memory + std::vector matrix_D(layout_D.capacity(extent_C)); + cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + offset_D.at(i), matrix_D.size()); + + cutlass::TensorView view_D(matrix_D.data(), layout_D, extent_C); + cutlass::TensorView view_Ref = host_D.host_view(); + + // Reference check + passed = cutlass::reference::host::TensorEquals(view_D, view_Ref); + + if (!passed) { + std::cerr << "\n***\nError - problem " << i << " failed the QA check\n***\n" << std::endl; + return passed; + } + } + + return passed; + } +}; + +template +class TestbedConventional : BaseTestbed { +public: + TestbedConventional( + Options &options_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3080 + ): BaseTestbed(options_, init_A_, init_B_, init_C_, seed_) {} + + /// Verbose printing of problem sizes + void print_problem_sizes() { + + // Print groups + std::cout << this->problem_count() << " groups:\n"; + + int32_t idx = 0; + int64_t total_tiles = 0; + + for (auto const & problem : this->options.problem_sizes) { + int tiles = + ((problem.m() + Rank2K::ThreadblockShape::kM - 1) / Rank2K::ThreadblockShape::kM) * + ((problem.n() + Rank2K::ThreadblockShape::kN - 1) / Rank2K::ThreadblockShape::kN); + + total_tiles += tiles; + + std::cout << " [" << idx << "]: " + << problem.m() << "-by-" << 
problem.n() << "-by-" << problem.k() + << " (" << tiles << " threadblock tiles)" << "\n"; + + ++idx; + } + std::cout << std::endl; + } + + /// Executes a conventional SYR2K kernel. + Result profile() { + std::cout << "Conventional Rank2K:\n" + << "====================================================" << std::endl; + + Result result; + result.passed = false; + + // Initialize the problem + this->allocate(); + this->initialize(); + + if (this->options.verbose) { + print_problem_sizes(); + } + + // + // Create CUDA streams to maximize concurrency of SYR2K kernels + // + int32_t effective_streams = (this->options.cuda_streams ? this->options.cuda_streams : 1); + std::vector cuda_streams; + char const *provider = "CUTLASS"; + + // + // Warmup run + // + + if (this->options.cuda_streams) { + for (int i = 0; i < this->options.cuda_streams; ++i) { + cudaStream_t stream; + + result.error = cudaStreamCreate(&stream); + if (result.error != cudaSuccess) { + std::cerr << "Failed to create CUDA stream." << std::endl; + return result; + } + cuda_streams.push_back(stream); + } + } + else { + cuda_streams.push_back(nullptr); + } + + // Use 'D' for the in/out workspace + this->block_D.copy_from_device(this->block_C.get()); + + for (size_t i = 0; i < this->options.problem_sizes.size(); ++i) { + cutlass::gemm::GemmCoord const & problem = this->options.problem_sizes[i]; + int32_t batch_count = 1; + int64_t lda = this->lda_host.at(i); + int64_t ldb = this->ldb_host.at(i); + int64_t ldc = this->ldc_host.at(i); + typename Rank2K::ElementA* ptrA = this->block_A.get() + this->offset_A.at(i); + typename Rank2K::ElementB* ptrB = this->block_B.get() + this->offset_B.at(i); + typename Rank2K::ElementC* ptrC = this->block_C.get() + this->offset_C.at(i); + typename Rank2K::ElementC* ptrD = this->block_D.get() + this->offset_D.at(i); + + // + // Initialize the CUTLASS SYR2K operator + // + + // Configure the SYR2K arguments + typename Rank2K::EpilogueOutputOp::Params epilogue_op(this->options.alpha, this->options.beta); + + typename Rank2K::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem, + batch_count, + epilogue_op, + (void const *)ptrA, + (void const *)ptrB, + (void const *)ptrC, + (void *)ptrD, + int64_t(), + int64_t(), + int64_t(), + int64_t(), + int64_t(lda), + int64_t(ldb), + int64_t(ldc), + int64_t(ldc) + }; + + Rank2K rank2k_op; + + cutlass::Status status = rank2k_op.initialize(arguments); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "CUTLASS error on line " << __LINE__ << std::endl; + return result; + } + + status = rank2k_op(); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "CUTLASS error on line " << __LINE__ << std::endl; + return result; + } + } + + // + // Wait for completion + // + + result.error = cudaDeviceSynchronize(); + + if (result.error != cudaSuccess) { + std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); + return result; + } + + // + // Construct events + // + + cudaEvent_t events[2]; + + for (auto & event : events) { + result.error = cudaEventCreate(&event); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; + return -1; + } + } + + // + // Wait for completion + // + + result.error = cudaDeviceSynchronize(); + + if (result.error != cudaSuccess) { + std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); + return result; + } + + // Record an event at the start of a series of SYR2K operations + result.error = 
cudaEventRecord(events[0]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // + // Run profiling loop + // + + size_t last_stream_idx = 0; + + for (int iter = 0; iter < this->options.iterations; ++iter) { + for (size_t i = 0; i < this->options.problem_sizes.size(); ++i) { + cutlass::gemm::GemmCoord const & problem = this->options.problem_sizes[i]; + int32_t batch_count = 1; + int64_t lda = this->lda_host.at(i); + int64_t ldb = this->ldb_host.at(i); + int64_t ldc = this->ldc_host.at(i); + typename Rank2K::ElementA* ptrA = this->block_A.get() + this->offset_A.at(i); + typename Rank2K::ElementB* ptrB = this->block_B.get() + this->offset_B.at(i); + typename Rank2K::ElementC* ptrC = this->block_C.get() + this->offset_C.at(i); + typename Rank2K::ElementC* ptrD = this->block_D.get() + this->offset_D.at(i); + + last_stream_idx = (i % effective_streams); + + // + // Initialize the CUTLASS SYR2K operator + // + + // Configure the SYR2K arguments + typename Rank2K::EpilogueOutputOp::Params epilogue_op(this->options.alpha, this->options.beta); + + typename Rank2K::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem, + batch_count, + epilogue_op, + (void const *)ptrA, + (void const *)ptrB, + (void const *)ptrC, + (void *)ptrD, + int64_t(), + int64_t(), + int64_t(), + int64_t(), + int64_t(lda), + int64_t(ldb), + int64_t(ldc), + int64_t(ldc) + }; + + Rank2K rank2k_op; + + cutlass::Status status = rank2k_op.initialize(arguments); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "CUTLASS error on line " << __LINE__ << std::endl; + return result; + } + + status = rank2k_op(cuda_streams[last_stream_idx]); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "CUTLASS error on line " << __LINE__ << std::endl; + return result; + } + } + } + + // + // Stop profiling loop + // + + // Record an event when the SYR2K operations have been launched. + result.error = cudaEventRecord(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // + // Wait for work to be completed + // + + result.error = cudaDeviceSynchronize(); + + if (result.error != cudaSuccess) { + std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); + return result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Compute average runtime and GFLOPs. 
+ result.runtime_ms = double(runtime_ms) / double(this->options.iterations); + result.gflops = this->options.gflops(result.runtime_ms / 1000.0); + + // + // Cleanup + // + + for (auto event : events) { + (void)cudaEventDestroy(event); + } + + for (auto stream : cuda_streams) { + if (stream) { + (void)cudaStreamDestroy(stream); + } + } + + std::cout << " " << this->options.problem_sizes.size() << " conventional Rank2Ks launched" << std::endl; + std::cout << std::endl; + std::cout << " " << "Conventional Runtime: " << result.runtime_ms << " ms" << std::endl; + std::cout << " " << "Conventional GFLOPS: " << result.gflops << std::endl; + + if (this->options.output_file.good()) { + this->options.output_file << this->options.output_tag << "," << provider << ",conventional," + << this->problem_count() << "," << result.runtime_ms << "," << result.gflops << std::endl; + } + + result.passed = true; + return result; + } +}; + +template +class TestbedGrouped : BaseTestbed { +public: + TestbedGrouped( + Options &options_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3080 + ) : BaseTestbed(options_, init_A_, init_B_, init_C_, seed_) {} + + // Redefine Rank2K with different GroupScheduleMode_ + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + typename Rank2K_::ElementA, typename Rank2K_::LayoutA, Rank2K_::kTransformA, Rank2K_::kAlignmentA, + typename Rank2K_::ElementB, typename Rank2K_::LayoutB, Rank2K_::kTransformB, Rank2K_::kAlignmentB, + typename Rank2K_::ElementC, typename Rank2K_::LayoutC, Rank2K_::kFillModeC, + typename Rank2K_::ElementAccumulator, + typename Rank2K_::OperatorClass, + typename Rank2K_::ArchTag, + typename Rank2K_::ThreadblockShape, + typename Rank2K_::WarpShape, + typename Rank2K_::InstructionShape, + typename Rank2K_::EpilogueOutputOp, + typename Rank2K_::ThreadblockSwizzle, + Rank2K_::kStages, + typename Rank2K_::Operator::ArchMmaOperator::Operator, + Rank2K_::kBlasMode, + GroupScheduleMode_>::Rank2Kkernel; + + using Rank2K = cutlass::gemm::device::Rank2KGrouped; + + /// Verbose printing of problem sizes + void print_problem_sizes() { + + // Print groups + std::cout << this->problem_count() << " groups:\n"; + + int32_t idx = 0; + int64_t total_tiles = 0; + + for (auto const & problem : this->options.problem_sizes) { + int tiles = Rank2K::problem_tile_count(problem); + total_tiles += tiles; + + std::cout << " [" << idx << "]: " + << problem.m() << "-by-" << problem.n() << "-by-" << problem.k() + << " (" << tiles << " threadblock tiles)" << "\n"; + + ++idx; + } + std::cout << std::endl; + } + + /// Sort problems in descending order of problem-K dimension + void sort_problems() { + Rank2K::sort_problems(this->options.problem_count, + this->options.problem_sizes.data(), + this->lda_host.data(), + this->ldb_host.data(), + this->ldc_host.data(), + this->ldd_host.data(), + this->offset_A.data(), + this->offset_B.data(), + this->offset_C.data(), + this->offset_D.data()); + } + + /// Executes a grouped kernel and measures runtime. 
+ Result profile() { + std::string sched_mode = this->options.scheduler_mode_to_str.find(GroupScheduleMode_)->second; + std::cout << std::endl; + std::cout << "Grouped Rank2K (CUTLASS) with mode " << sched_mode << ":\n" + << "====================================================" << std::endl; + + Result result; + + int threadblock_count = Rank2K::sufficient(this->options.problem_sizes.data(), this->options.problem_count); + + // Early exit + if (!threadblock_count) { + std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Grouped SYR2K kernel." << std::endl; + return result; + } + + result.passed = false; + + // Initialize the problem + this->allocate(); + if (this->options.sort_problems) { + sort_problems(); + } + this->initialize(); + + if (this->options.verbose) { + print_problem_sizes(); + } + + // Configure the Rank2K arguments + typename Rank2K::EpilogueOutputOp::Params epilogue_op(this->options.alpha, this->options.beta); + + // Configure Rank2K arguments + typename Rank2K::Arguments args( + cutlass::gemm::GemmUniversalMode::kGemm, + this->problem_sizes_device.get(), + this->problem_count(), + threadblock_count, + epilogue_op, + this->ptr_A.get(), + this->ptr_B.get(), + this->ptr_C.get(), + this->ptr_D.get(), + this->lda.get(), + this->ldb.get(), + this->ldc.get(), + this->ldd.get(), + this->options.problem_sizes.data() + ); + + // Initialize the Rank2K object + Rank2K rank2k{}; + size_t workspace_size = rank2k.get_workspace_size(args); + cutlass::DeviceAllocation workspace(workspace_size); + + result.status = rank2k.initialize(args, workspace.get()); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize CUTLASS Grouped Rank2K kernel." << std::endl; + return result; + } + + // Run the grouped Rank2K object + result.status = rank2k.run(); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run CUTLASS Grouped Rank2K kernel." << std::endl; + return result; + } + + // Wait for completion + result.error = cudaDeviceSynchronize(); + + if (result.error != cudaSuccess) { + std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); + return result; + } + + // + // Verify correctness + // + result.passed = true; + if (this->options.reference_check) { + result.passed = this->verify(); + } + + // + // Warm-up run of the grouped Rank2K object + // + result.status = rank2k.run(); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run CUTLASS Grouped Rank2K kernel." << std::endl; + return result; + } + + // + // Construct events + // + + cudaEvent_t events[2]; + + for (auto & event : events) { + result.error = cudaEventCreate(&event); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; + return -1; + } + } + + // Record an event at the start of a series of SYR2K operations + result.error = cudaEventRecord(events[0]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // + // Run profiling loop + // + + for (int iter = 0; iter < this->options.iterations; ++iter) { + rank2k(); + } + + // + // Stop profiling loop + // + + // Record an event when the Rank2K operations have been launched. 
+ result.error = cudaEventRecord(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Wait for work on the device to complete. + result.error = cudaEventSynchronize(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Compute average runtime and GFLOPs. + result.runtime_ms = double(runtime_ms) / double(this->options.iterations); + result.gflops = this->options.gflops(result.runtime_ms / 1000.0); + + // + // Cleanup + // + + for (auto event : events) { + (void)cudaEventDestroy(event); + } + + // Optionally profile initialization + if (this->options.profile_initialization) { + // Warm up + rank2k.initialize(args, workspace.get()); + + auto start_time = std::chrono::high_resolution_clock::now(); + for (int32_t i = 0; i < this->options.iterations; ++i) { + rank2k.initialize(args, workspace.get()); + } + auto end_time = std::chrono::high_resolution_clock::now(); + + std::chrono::duration duration = end_time - start_time; + duration /= double(this->options.iterations); + result.initialization_time_ms = duration.count(); + } + + int64_t total_tiles = Rank2K::group_tile_count(args); + std::cout << " " << total_tiles << " total threadblock tiles." << std::endl; + + std::cout << std::endl; + std::cout << " " << "Grouped Runtime: " << result.runtime_ms << " ms" << std::endl; + std::cout << " " << "Grouped GFLOPs: " << result.gflops << std::endl; + if (this->options.profile_initialization) { + std::cout << " " << "Init Runtime: " << result.initialization_time_ms << " ms" << std::endl; + } + + if (this->options.output_file.good()) { + this->options.output_file << this->options.output_tag << ",CUTLASS,grouped-" << sched_mode << "," + << this->problem_count() << "," << result.runtime_ms << "," << result.gflops << std::endl; + } + + std::cout << "\nPassed\n"; + + return result; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) { + + // + // This example requires an NVIDIA Ampere-architecture GPU. + // + + std::cout + << "CUTLASS's Grouped Rank2K example requires a GPU of NVIDIA's Ampere Architecture or " + << "later (compute capability 80 or greater).\n"; + + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." 
<< std::endl; + return -1; + } + + // + // Define the Grouped and Conventional Rank2K types + // + + using ElementA = double; + using ElementB = double; + using ElementOutput = double; + using ElementAccumulator = double; + const cutlass::FillMode kFillModeC = cutlass::FillMode::kLower; + const int kAlignmentA = 1; + const int kAlignmentB = 1; + const cutlass::ComplexTransform kTransformA = cutlass::ComplexTransform::kNone; + const cutlass::ComplexTransform kTransformB = cutlass::ComplexTransform::kNone; + + using LayoutA = cutlass::layout::ColumnMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + using OperatorClass = cutlass::arch::OpClassTensorOp; + using ArchTag = cutlass::arch::Sm80; + + using ThreadblockShape = cutlass::gemm::GemmShape<32, 32, 16>; + using WarpShape = cutlass::gemm::GemmShape<16, 16, 16>; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + + using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, 1, + ElementAccumulator, ElementAccumulator>; + + // NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels. + // This parameter is passed in at present to match the APIs of other kernels. The parameter + // is unused within the kernel. + using ThreadblockSwizzle = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + + const int kStages = 4; + const bool kSplitKSerial = false; + using Operator = cutlass::arch::OpMultiplyAdd; + const cutlass::BlasMode kBlasMode = cutlass::BlasMode::kSymmetric; + + // Define a grouped Rank2K kernel with all template parameters set except + // for scheduling mode. This will be used as the template for all scheduling + // modes executed. + using Rank2Kkernel = typename cutlass::gemm::kernel::DefaultRank2KGrouped< + ElementA, LayoutA, kTransformA, kAlignmentA, + ElementB, LayoutB, kTransformB, kAlignmentB, + ElementOutput, LayoutC, kFillModeC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + Operator, + kBlasMode>::Rank2Kkernel; + + using Rank2KGrouped = cutlass::gemm::device::Rank2KGrouped; + + // Rank2k operator + using Rank2KConventional = cutlass::gemm::device::Rank2K< + ElementA, LayoutA, + ElementB, LayoutB, + ElementOutput, LayoutC, kFillModeC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + kStages, + kAlignmentA, + kAlignmentB, + kSplitKSerial, + Operator, + kTransformA, + kTransformB, + kBlasMode + >; + + // + // Profile it + // + + TestbedConventional testbed(options); + + Result result = testbed.profile(); + if (!result.passed) { + std::cout << "Profiling CUTLASS conventional Rank2K has failed.\n"; + std::cout << "\nFailed\n"; + return -1; + } + + using GroupScheduleMode = cutlass::gemm::kernel::GroupScheduleMode; + for (GroupScheduleMode mode : options.scheduler_modes) { + Result result; + switch (mode) { + case GroupScheduleMode::kDeviceOnly: + { + TestbedGrouped runner(options); + result = runner.profile(); + break; + } + case GroupScheduleMode::kHostPrecompute: + { + TestbedGrouped runner(options); + result = runner.profile(); + break; + } + } + + if (result.error != cudaSuccess) { + return 1; + } + + // Override verbose flag to avoid printing duplicate information for each scheduling mode + options.verbose = false; + } + + return 0; +} + 
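The FLOP accounting used by Options::gflops() above can be easier to follow in isolation. The sketch below restates it outside the example (the function name syr2k_gflops and the (n, k) pair representation are illustrative choices, not part of the example code): each rank-2k update D = alpha * (A * B^T + B * A^T) + beta * C with N-by-K operands is counted as 2 * N * N * K fused multiply-adds, i.e. twice the FMAs of one N-by-N-by-K GEMM, and each FMA contributes two floating-point operations.

#include <cstdint>
#include <utility>
#include <vector>

// Restates the accounting in Options::gflops(): FMAs = 2 * sum_i(N_i * N_i * K_i), FLOPs = 2 * FMAs.
double syr2k_gflops(std::vector<std::pair<int64_t, int64_t>> const &nk_sizes, double runtime_s) {
  int64_t fmas = 0;
  for (auto const &nk : nk_sizes) {
    fmas += nk.first * nk.first * nk.second;   // one N-by-N-by-K GEMM worth of FMAs per group
  }
  fmas *= 2;                                   // two rank-k updates: A * B^T and B * A^T
  return 2.0 * double(fmas) / 1.0e9 / runtime_s;
}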
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/39_gemm_permute/CMakeLists.txt b/examples/39_gemm_permute/CMakeLists.txt new file mode 100644 index 0000000000..dd916fdf5d --- /dev/null +++ b/examples/39_gemm_permute/CMakeLists.txt @@ -0,0 +1,36 @@ + +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + +cutlass_example_add_executable( + 39_gemm_permute + gemm_permute.cu + ) + diff --git a/examples/39_gemm_permute/gemm_permute.cu b/examples/39_gemm_permute/gemm_permute.cu new file mode 100644 index 0000000000..3651b9c568 --- /dev/null +++ b/examples/39_gemm_permute/gemm_permute.cu @@ -0,0 +1,1223 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief GEMM Permute Example. + + This example computes batched GEMM operations with output results permuted as reshaped tensors. + + We provide layout plugin as a flexible tool for users to add any customized input/output tensor permute operation, + or any other generalized global memory writeout address computation. To add a customized layout, add new class + in include/cutlass/layout/permute.h + + In this example we use several permute operations (permute([0, 2, 1, 3])) + + In this example, we used Tensor4DPermuteBMM0213 layout to perform Batched GEMM with permute([0, 2, 1, 3]) on BMM + whole output tensor, and used Tensor5DPermute20314 layout to perform Normal GEMM with permute([2, 0, 3, 1, 4]) on + output matrix. The address computations are performed in compute(col_init, row_init, stride_init, + BMM_batch_idx) with {col_permute, row_permute and stride_permute} as new addresses after permute op. + (check include/cutlass/layout/permute.h) + + Tips: + + 1) Make sure to set batch_stride to zero for BMM permute; also the BMM GEMM should be in mode + cutlass::gemm::GemmUniversalMode::kBatched instead of kArray. + + 2) When the contiguous dimension is touched in permute op (for example [0, 2, 3, 1] for row-major matrix + or [1, 0, 2, 3] for column-major), Alignment should be set to 1 for the corresponding matrix. + If the last dimension is untouched, one can set Alignment to be larger like 8 in our example. + As a result, permute op without touching the unit stride dimension is recommended to obtain the best performance. 
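    To make the Tensor4DPermuteBMM0213 description above concrete, the following coordinate walk-through is an
    editorial sketch inferred from the reshape/permute description (the authoritative address computation lives in
    include/cutlass/layout/permute.h). For an element (b, m, n) of the [B, M, N] batched output, with 0 <= b < B:

        reshape  [B, M, N] -> [B/D1, D1, M, N]      : coordinate becomes (b / D1, b % D1, m, n)
        permute([0, 2, 1, 3])                       : coordinate becomes (b / D1, m, b % D1, n)
        view as  [B/D1, M, D1*N] batched matrix     : batch = b / D1, row = m, column = (b % D1) * N + n

    so the M dimension is untouched while the D1 sub-batch index is interleaved into the output columns.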
+ + Examples: + + # Runs a batched GEMM with 96 batches + $ ./examples/39_gemm_permute/39_gemm_permute --problem-count=96 + + # Runs a batched GEMM with 96 batches (with GEMM-K dimension equal to 1024) + $ ./examples/39_gemm_permute/39_gemm_permute --problem-count=96 --k=1024 --verbose=true + + # Execute batched GEMM and profile with NSight + $ nv-nsight-cu-cli ./examples/39_gemm_permute/39_gemm_permute --m=256 --n=192 --k=256 --verbose=true --iterations=1 --reference-check=false + +*/ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm_complex.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" + +#include "cutlass/layout/permute.h" + +#include "layouts.h" +#include "permute_info.h" + +/// Tensor4DPermuteBMM0213 ---> +/// Permute layout function for 4-D permuted tensors for BMM with BMM tensor (dimension as [B, M, N]) reshaped +/// as [B/D1, D1, M, N]. Then perform permute([0, 2, 1, 3]) on the corresponding whole BMM tensor. +int constexpr D1 = 12; + +/// Tensor5DPermute20314 ---> +/// Permute layout function for 5-D permuted tensors with matrix (dimension as [M, N]) reshaped +/// as [M/T1, T1, T2, T3, N/T2/T3]. Then perform permute([2, 0, 3, 1, 4]) on the corresponding tensor. +int constexpr T1 = 16; +int constexpr T2 = 3; +int constexpr T3 = 8; + +/// Tensor4DPermute0213 ---> +/// Permute layout function for 4-D permuted tensors with matrix (dimension as [M, N]) reshaped +/// as [M/S1, S1, S2, N/S2]. Then perform permute([0, 2, 1, 3]) on the corresponding tensor. 
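// Editorial illustration (not part of the original example): following the reshape/permute description above,
// Tensor4DPermute0213 maps a logical element (m, n) of the [M, N] matrix to position
//
//   row'    = (m / S1) * S2 + n / (N / S2)
//   column' = (m % S1) * (N / S2) + n % (N / S2)
//
// of the resulting [M*S2/S1, N*S1/S2] matrix. This is a sketch inferred from the comment above; the actual
// address computation is implemented by the layout classes in include/cutlass/layout/permute.h.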
+int constexpr S1 = 8; +int constexpr S2 = 4; + +// // // Alignments +int constexpr AlignmentA = 8; +int constexpr AlignmentB = 8; +int constexpr AlignmentC = 8; + +/// GEMM element types +using ElementInput = cutlass::half_t; +using ElementOutput = cutlass::half_t; +using ElementAccumulator = float; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Useful macros + +#define CHECK_CUDA_CALL(call, handler) \ +do { \ + cudaError_t __err = (call); \ + if (__err != cudaSuccess) { \ + std::cerr << #call " failed: " << cudaGetErrorString(__err) << std::endl; \ + handler; \ + } \ +} while(0) + +#define CHECK_CUTLASS_CALL(call, handler) \ +do { \ + cutlass::Status __status = (call); \ + if (__status != cutlass::Status::kSuccess) { \ + std::cerr << #call " failed: " << cutlass::cutlassGetStatusString(__status) << std::endl; \ + handler; \ + } \ +} while(0) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + bool error; + bool reference_check; + + cutlass::gemm::GemmCoord problem_each; + + int batch_count; + int iterations; + int cuda_streams; + bool verbose; + float alpha; + float beta; + + // + // Methods + // + + Options(): + help(false), + error(false), + reference_check(true), + batch_count(-1), + iterations(20), + cuda_streams(0), + verbose(false), + alpha(1), + beta() + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("alpha", alpha, 1.0f); + cmd.get_cmd_line_argument("beta", beta, 0.0f); + cmd.get_cmd_line_argument("iterations", iterations, 20); + cmd.get_cmd_line_argument("streams", cuda_streams, 0); + cmd.get_cmd_line_argument("verbose", verbose, false); + cmd.get_cmd_line_argument("reference-check", reference_check, true); + + int m, n, k; + + cmd.get_cmd_line_argument("m", m, 384); + cmd.get_cmd_line_argument("n", n, 192); + cmd.get_cmd_line_argument("k", k, 384); + cmd.get_cmd_line_argument("batch-count", batch_count, 96); + + problem_each = cutlass::gemm::GemmCoord(m, n, k); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << + "39_gemm_permute\n" + "\n" + " This example tests and profiles the performance of normal GEMM and batched GEMM with different" + " combinations of fused permutations of input and output tensors." + "\n" + " Permutations considered in this example:\n" + "\n" + " Normal GEMM:\n" + " 1) Tensor4DPermute0213: matrix of shape [X, Y] is reshaped as [X/S1, S1, S2, Y/S2] and has its dimensions" + " permuted as [0, 2, 1, 3], resulting in shape [X/S1, S2, S1, Y/S2] viewed as matrix of shape [X*S2/S1, Y*S1/S2].\n" + " 2) Tensor5DPermute20314: matrix of shape [X, Y] is reshaped as [X/T1, T1, T2, T3, Y/T2/T3] and has its dimensions" + " permuted as [2, 0, 3, 1, 4], resulting in shape [T2, X/T1, T3, T1, Y/T2/T3] viewed as matrix of shape [X*T2/T1, Y*T1/T2].\n" + "\n" + " Batched GEMM:\n" + " 3) Tensor4DPermuteBMM0213: batched tensor of 3D shape [B, X, Y] is reshaped as 4D shape [B/D1, D1, X, Y]" + " and has its dimensions permuted as [0, 2, 1, 3], resulting in shape [B/D1, X, D1, Y] viewed as" + " a matrix of shape [B/D1, X, Y*D1] for batched GEMM purposes.\n" + "\n" + " Note: S1, S2, D1, D2, T1, T2, T3 are compile-time constants defined in gemm_permute.cu." 
+ " Runtime specification of these values is not supported." + " These values along with alignment requirements place constraints on supported matrix sizes.\n" + "\n" + " Note: X, Y above may refer to M, N or K dimensions of GEMM problem, depending on the tensor considered (A, B or D)." + " For the output tensor D the values correspond directly to dimensions of D, whereas for A and B the original dimensions" + " X', Y' are inferred from the ones supplied to the GEMM, taking into account the permute operation.\n" + "\n" + "Options:\n" + "\n" + " --help If specified, displays this usage statement.\n\n" + " --batch-count= Sets the number of batches in batched GEMM (batch number for BMM). (default: --batch-count=768)\n" + " --m= Sets the M dimension for both batched GEMM and normal GEMM problems. (default: --m=128)\n" + " --n= Sets the N dimension for both batched GEMM and normal GEMM problems. (default: --n=192)\n" + " --k= Sets the K dimension for both batched GEMM and normal GEMM problems. (default: --k=384)\n" + " --alpha= Epilogue scalar alpha (real part)\n" + " --beta= Epilogue scalar beta (real part)\n\n" + " --iterations= Number of profiling iterations to perform.\n" + " --reference-check= If true, performs reference check.\n" + " --verbose= If true, prints problem sizes and batching structure.\n" + "\n" + "Examples:\n" + "\n" + "# Runs a batched GEMM with 96 batches\n" + "$ ./examples/39_gemm_permute/39_gemm_permute --batch-count=96\n" + "\n" + "# Runs a batched GEMM with 96 batches (with GEMM-K dimension equal to 1024)\n" + "$ ./examples/39_gemm_permute/39_gemm_permute --batch-count=96 --k=1024 --verbose=true\n" + "\n" + "# Execute batched GEMM and profile with NSight\n" + "$ nv-nsight-cu-cli ./examples/39_gemm_permute/39_gemm_permute --m=256 --n=192 --k=256 --verbose=true --iterations=1 --reference-check=false\n" + "\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s, bool batched) const { + + // Number of real-valued multiply-adds + int64_t fmas = int64_t(); + + fmas += problem_each.product() * (batched ? 
batch_count : 1); + + // Two flops per multiply-add + return 2.0 * double(fmas) / double(1.0e9) / runtime_s; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace { // (anonymous) + +/// Dimension-generic permutation loop +template +void permute_host_impl( + cutlass::TensorView const & input, + cutlass::TensorView const & output, + PermuteOp && permute, + Coord & coord +) { + static_assert(Layout::kRank == Coord::kRank, "Incompatible Layout and Coord types"); + if constexpr (I == Coord::kRank) { + output.at(permute(coord)) = input.at(coord); + } + else { + for (coord[I] = 0; coord[I] < input.extent(I); ++coord[I]) { + permute_host_impl(input, output, std::forward(permute), coord); + } + } +} + +} // namespace (anonymous) + +/// Perform a reference (host-based) permutation of an input tensor +template +void permute_host( + cutlass::TensorView const &input, + cutlass::TensorView const &output, + int batch_count) { + Layout layout = input.layout(); + cutlass::MatrixCoord extent = input.extent(); + + std::size_t num_elems = layout.capacity(extent) * batch_count; + std::vector h_input(num_elems); + cutlass::device_memory::copy_to_host(h_input.data(), input.data(), num_elems); + + std::vector h_output(num_elems); + + using Info = PermuteInfo; + using TensorLayout = typename Info::Layout; + + auto shape_orig = Info::original_shape(extent, batch_count); + auto shape_perm = Info::permute(shape_orig); + + cutlass::TensorView view_input(h_input.data(), TensorLayout::packed(shape_orig), shape_orig); + cutlass::TensorView view_output(h_output.data(), TensorLayout::packed(shape_perm), shape_perm); + + decltype(shape_orig) coord; + permute_host_impl<0>(view_input, view_output, Info::permute, coord); + + cutlass::device_memory::copy_to_device(output.data(), h_output.data(), num_elems); +} + +template +struct LayoutInfo; + +template<> +struct LayoutInfo { + static std::string name() { return "RowMajor"; } +}; + +template<> +struct LayoutInfo { + static std::string name() { return "ColumnMajor"; } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class Testbed { +private: + + // + // Data members + // + + Options & options; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint32_t seed; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + +public: + + // + // Methods + // + + Testbed( + Options &options_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3090 + ): + options(options_), init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + +private: + + /// Print permutation info for one tensor + template + void print_tensor_info( + std::ostream & os, + std::string const &tensor_name, + int row_dim, + int col_dim) { + + cutlass::MatrixCoord extent(options.problem_each.at(row_dim), options.problem_each.at(col_dim)); + using Info = PermuteInfo; + + os << "tensor " << tensor_name << ": " << Info::desc() << "\n"; + os << " extent: [" << extent.row() << ", " << extent.column() << "]"; + if (Info::kBatched) { + os << ", batch count: " << options.batch_count; + } + os << "\n"; + if 
(!cutlass::layout::is_trivial_permute) { + auto shape_orig = Info::original_shape(extent, options.batch_count); + auto shape_perm = Info::permute(shape_orig); + os << " original: [" << shape_orig << "]\n"; + os << " permuted: [" << shape_perm << "]\n"; + } + } + + /// Check shape compatibility for one tensor + template + bool check_tensor_shape( + std::string const &tensor_name, + int row_dim, + int col_dim) { + + cutlass::MatrixCoord extent(options.problem_each.at(row_dim), options.problem_each.at(col_dim)); + + using Info = PermuteInfo; + + auto rowAlign = cutlass::platform::is_same::value ? Alignment : 1; + auto colAlign = cutlass::platform::is_same::value ? Alignment : 1; + + auto rowFactor = Info::kRowFactor * rowAlign; + auto colFactor = Info::kColumnFactor * colAlign; + + // Assumes row-major layout + bool const valid_row = extent.row() % rowFactor == 0; + if (!valid_row) { + std::cerr << "Invalid tensor " << tensor_name << " row size = " << extent.row() << ", " + "must be divisible by " << rowFactor << ", " + "required by " << Info::name() << + (rowAlign > 1 ? (" and alignment of " + std::to_string(rowAlign)) : "") << std::endl; + } + + bool const valid_col = extent.column() % colFactor == 0; + if (!valid_col) { + std::cerr << "Invalid tensor " << tensor_name << " column size = " << extent.column() << ", " + "must be divisible by " << colFactor << ", " + "required by " << Info::name() << + (colAlign > 1 ? (" and alignment of " + std::to_string(colAlign)) : "") << std::endl; + } + + bool const valid_bsz = options.batch_count % Info::kBatchFactor == 0; + if (!valid_bsz) { + std::cerr << "Invalid batch count = " << options.batch_count << ", " + "must be divisible by " << Info::kBatchFactor << ", " + "required by " << Info::name() << std::endl; + } + + return valid_row && valid_col && valid_bsz; + } + + /// Helper to initialize a tensor view + template + void initialize_tensor_( + Element *ptr, + size_t capacity, + cutlass::Distribution::Kind dist_kind, + uint32_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::device::BlockFillRandomUniform( + ptr, capacity, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::device::BlockFillRandomGaussian( + ptr, capacity, seed, Element(), Element(0.5f)); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + // Fill with increasing elements + cutlass::reference::device::BlockFillSequential( + ptr, capacity, Element(1), Element()); + } + else { + + // Fill with all 1s + cutlass::reference::device::BlockFillSequential( + ptr, capacity, Element(), Element(1)); + } + } + + /// Initializes data structures + void initialize(int batch_count) { + + srand(seed); + + int64_t total_elements_A = options.problem_each.m() * options.problem_each.k() * batch_count; + int64_t total_elements_B = options.problem_each.n() * options.problem_each.k() * batch_count; + int64_t total_elements_C = options.problem_each.m() * options.problem_each.n() * batch_count; + int64_t total_elements_D = 
options.problem_each.m() * options.problem_each.n() * batch_count; + + // Allocate space + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + + // Initialize input tensors + initialize_tensor_(block_A.get(), total_elements_A, init_A, seed * 2021); + initialize_tensor_(block_B.get(), total_elements_B, init_B, seed * 2022); + initialize_tensor_(block_C.get(), total_elements_C, init_C, seed * 2023); + + cutlass::reference::device::BlockFillSequential( + block_D.get(), total_elements_D, ElementC(), ElementC()); + } + + + /// Check device GEMM results against a reference implementation with separate host-based permutation + template + bool validate(Gemm const &gemm) { + + bool constexpr kBatched = PermuteInfo::kBatched + || PermuteInfo::kBatched + || PermuteInfo::kBatched; + + int const batch_count = kBatched ? options.batch_count : 1; + + cutlass::gemm::GemmCoord problem = options.problem_each; + + cutlass::MatrixCoord extent_A{problem.m(), problem.k()}; + cutlass::MatrixCoord extent_B{problem.k(), problem.n()}; + cutlass::MatrixCoord extent_C{problem.m(), problem.n()}; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + + LayoutA layout_A(LayoutA::packed(extent_A)); + LayoutB layout_B(LayoutB::packed(extent_B)); + LayoutC layout_C(LayoutC::packed(extent_C)); + + auto size_A = layout_A.capacity(extent_A) * batch_count; + auto size_B = layout_B.capacity(extent_B) * batch_count; + auto size_C = layout_C.capacity(extent_C) * batch_count; + + cutlass::TensorView view_A(block_A.get(), layout_A, extent_A); + cutlass::TensorView view_B(block_B.get(), layout_B, extent_B); + cutlass::TensorView view_C(block_C.get(), layout_C, extent_C); + cutlass::TensorView view_D(block_D.get(), layout_C, extent_C); + + cutlass::DeviceAllocation block_A_perm(size_A); + cutlass::DeviceAllocation block_B_perm(size_B); + + cutlass::TensorView view_A_perm(block_A_perm.get(), layout_A, extent_A); + cutlass::TensorView view_B_perm(block_B_perm.get(), layout_B, extent_B); + + permute_host(view_A.const_view(), view_A_perm, batch_count); + permute_host(view_B.const_view(), view_B_perm, batch_count); + + cutlass::DeviceAllocation block_D_ref(size_C); + cutlass::TensorView view_D_ref(block_D_ref.get(), layout_C, extent_C); + + using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; + + // Reference GEMM + cutlass::reference::device::GemmComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + typename EpilogueOutputOp::ElementCompute, + typename Gemm::ElementAccumulator + >( + problem, + options.alpha, + view_A_perm, + Gemm::kTransformA, + view_B_perm, + Gemm::kTransformB, + options.beta, + view_C, + view_D_ref, + ElementAccumulator(0), + batch_count, + options.problem_each.m() * options.problem_each.k(), + options.problem_each.n() * options.problem_each.k(), + options.problem_each.m() * options.problem_each.n(), + options.problem_each.m() * options.problem_each.n() + ); + + cutlass::DeviceAllocation block_D_perm(size_C); + cutlass::TensorView view_D_perm(block_D_perm.get(), layout_C, extent_C); + permute_host(view_D_ref.const_view(), view_D_perm, batch_count); + + // Reference check + return cutlass::reference::device::BlockCompareEqual(view_D_perm.data(), view_D.data(), size_C); +} + +public: + + template + bool profile_GEMM_permute() { + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using 
LayoutC = typename Gemm::LayoutC; + + using PermuteALayout = typename Gemm::PermuteALayout; + using PermuteBLayout = typename Gemm::PermuteBLayout; + using PermuteDLayout = typename Gemm::PermuteDLayout; + + bool constexpr kBatched = PermuteInfo::kBatched + || PermuteInfo::kBatched + || PermuteInfo::kBatched; + + std::cout << "\n" + "====================================================\n" + << (kBatched ? "Batched" : "Normal") << " GEMM:" + << "\n A=" << LayoutInfo::name() << "," << PermuteInfo::name() + << "\n B=" << LayoutInfo::name() << "," << PermuteInfo::name() + << "\n D=" << LayoutInfo::name() << "," << PermuteInfo::name() + << "\n" + "====================================================\n"; + + if (options.verbose) { + print_tensor_info(std::cout, "A", 0, 2); + print_tensor_info(std::cout, "B", 2, 1); + print_tensor_info(std::cout, "D", 0, 1); + } + std::cout << std::endl; + + bool valid = true; + valid &= check_tensor_shape("A", 0, 2); + valid &= check_tensor_shape("B", 2, 1); + valid &= check_tensor_shape("D", 0, 1); + if (!valid) + { + std::cout << "Skipped test" << std::endl; + return true; + } + + int const batch_count = kBatched ? options.batch_count : 1; + + // Initialize the problem + initialize(batch_count); + + // Configure the GEMM arguments + using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; + typename EpilogueOutputOp::Params epilogue_op(options.alpha, options.beta); + + // Please make sure all problem_sizes are the same for kBatched mode + auto problem = options.problem_each; + + cutlass::MatrixCoord extent_A{problem.m(), problem.k()}; + cutlass::MatrixCoord extent_B{problem.k(), problem.n()}; + cutlass::MatrixCoord extent_C{problem.m(), problem.n()}; + + LayoutA layout_A(LayoutA::packed(extent_A)); + LayoutB layout_B(LayoutB::packed(extent_B)); + LayoutC layout_C(LayoutC::packed(extent_C)); + + // Configure GEMM arguments + typename Gemm::Arguments arguments{ + kBatched ? cutlass::gemm::GemmUniversalMode::kBatched : cutlass::gemm::GemmUniversalMode::kGemm, + problem, + batch_count, + epilogue_op, + (void*)block_A.get(), + (void*)block_B.get(), + (void*)block_C.get(), + (void*)block_D.get(), + // For any non-trivial permute the batch stride must be set to 0 + cutlass::layout::is_trivial_permute ? layout_A.capacity(extent_A) : 0, + cutlass::layout::is_trivial_permute ? layout_B.capacity(extent_B) : 0, + layout_C.capacity(extent_C), + cutlass::layout::is_trivial_permute ? 
layout_C.capacity(extent_C) : 0, + layout_A.stride(0), + layout_B.stride(0), + layout_C.stride(0), + layout_C.stride(0), + }; + + // Initialize the GEMM object + Gemm gemm_normal; + + CHECK_CUTLASS_CALL(gemm_normal.initialize(arguments, nullptr), return false); + + // Run the normal GEMM object + CHECK_CUTLASS_CALL(gemm_normal.run(), return false); + + // Wait for completion + CHECK_CUDA_CALL(cudaDeviceSynchronize(), return false); + + // + // Verify correctness + // + if (options.reference_check) { + if (validate(gemm_normal)) { + std::cout << "\nPassed verification\n" << std::endl; + } + else { + std::cerr << "\n*** Error - problem failed the QA check ***\n" << std::endl; + return false; + } + } + + // Warm-up run of the normal GEMM object + CHECK_CUTLASS_CALL(gemm_normal.run(), return false); + + // Construct events + cudaEvent_t events[2]; + for (auto & event : events) { + CHECK_CUDA_CALL(cudaEventCreate(&event), return false); + } + + // Record an event at the start of a series of GEMM operations + CHECK_CUDA_CALL(cudaEventRecord(events[0]), return false); + + // Run profiling loop + for (int iter = 0; iter < options.iterations; ++iter) { + gemm_normal(); + } + + // Record an event when the GEMM operations have been launched. + CHECK_CUDA_CALL(cudaEventRecord(events[1]), return false); + + // Wait for work on the device to complete. + CHECK_CUDA_CALL(cudaEventSynchronize(events[1]), return false); + + // Measure elapsed runtime + float runtime_total_ms = 0; + CHECK_CUDA_CALL(cudaEventElapsedTime(&runtime_total_ms, events[0], events[1]), return false); + + // Compute average runtime and GFLOPs. + double runtime_avg_ms = double(runtime_total_ms) / double(options.iterations); + double gflops = options.gflops(runtime_avg_ms / 1000.0, kBatched); + + // Cleanup + for (auto event : events) { + CHECK_CUDA_CALL(cudaEventDestroy(event), return false); + } + + std::cout << " Runtime: " << runtime_avg_ms << " ms\n" + " GFLOPs: " << gflops << std::endl; + + return true; + } +}; + +/// Shorthand alist for GEMM instantiations +template +using GemmPermute = cutlass::gemm::device::GemmUniversal< + ElementInput, LayoutA, + ElementInput, LayoutB, + ElementOutput, LayoutC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + cutlass::gemm::GemmShape<128, 128, 32>, + cutlass::gemm::GemmShape<64, 64, 32>, + cutlass::gemm::GemmShape<16, 8, 16>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, + AlignmentC, //128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementAccumulator + >, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + 4, /*kStages*/ + AlignmentA, /*AlignmentA*/ + AlignmentB, /*AlignmentB*/ + cutlass::arch::OpMultiplyAdd, + cutlass::ComplexTransform::kNone, + cutlass::ComplexTransform::kNone, + false, /*GatherA*/ + false, /*GatherB*/ + false, /*ScatterD*/ + PermuteDLayout, /*PermuteDLayout*/ + typename cutlass::layout::InversePermute::type, /*PermuteALayout*/ + typename cutlass::layout::InversePermute::type /*PermuteBLayout*/ +>; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // + // This example uses mma.sync to directly access Tensor Cores to achieve peak performance. + // + + cudaDeviceProp props; + + CHECK_CUDA_CALL(cudaGetDeviceProperties(&props, 0), return EXIT_FAILURE); + + if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) { + + // + // This example requires an NVIDIA Ampere-architecture GPU. 
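+    // The GemmPermute kernels defined above are instantiated for cutlass::arch::Sm80 with
+    // Tensor Core MMA, so both a CUDA 11+ toolkit and a device of compute capability 8.0
+    // or greater are required; otherwise the example exits early with the message below.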
+ // + + std::cout << "CUTLASS's GEMM+Permute example requires a GPU of NVIDIA's Ampere Architecture " + "or later (compute capability 80 or greater).\n"; + + return EXIT_SUCCESS; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return EXIT_SUCCESS; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return EXIT_FAILURE; + } + + // + // Define GEMM types to test + // + + // + // TTT (Row-major) GEMMs + // + + using TTTGemmNormalPermuteNone = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::NoPermute + >; + + using TTTGemmNormalPermuteA = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::NoPermute + >; + + using TTTGemmNormalPermuteAD = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor + >; + + using TTTGemmNormalPermuteB = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::NoPermute + >; + + using TTTGemmNormalPermuteBD = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor + >; + + using TTTGemmNormalPermuteD = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor + >; + + using TTTGemmNormalPermuteAB = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::NoPermute + >; + + using TTTGemmNormalPermuteABD = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor + >; + + // + // NNN (Col-major) GEMMs + // + + using NNNGemmNormalPermuteNone = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute + >; + + using NNNGemmNormalPermuteA = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute + >; + + using NNNGemmNormalPermuteAD = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor + >; + + using NNNGemmNormalPermuteB = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor, + 
cutlass::layout::ColumnMajor, cutlass::layout::NoPermute + >; + + using NNNGemmNormalPermuteBD = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor + >; + + using NNNGemmNormalPermuteD = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor + >; + + using NNNGemmNormalPermuteAB = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute + >; + + using NNNGemmNormalPermuteABD = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor + >; + + // + // NNT (Col-major inputs, row-major output) GEMMs + // + + using NNTGemmNormalPermuteNone = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::NoPermute + >; + + using NNTGemmNormalPermuteA = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::NoPermute + >; + + using NNTGemmNormalPermuteAD = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor + >; + + using NNTGemmNormalPermuteB = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor, + cutlass::layout::RowMajor, cutlass::layout::NoPermute + >; + + using NNTGemmNormalPermuteBD = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor, + cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor + >; + + using NNTGemmNormalPermuteD = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor + >; + + using NNTGemmNormalPermuteAB = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor, + cutlass::layout::RowMajor, cutlass::layout::NoPermute + >; + + using NNTGemmNormalPermuteABD = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermute0213ColumnMajor, + cutlass::layout::RowMajor, cutlass::layout::Tensor5DPermute20314RowMajor + >; + + // + // TTN (Row-major inputs, col-major output) GEMMs + // + + using TTNGemmNormalPermuteNone = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, 
cutlass::layout::NoPermute + >; + + using TTNGemmNormalPermuteA = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute + >; + + using TTNGemmNormalPermuteAD = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor + >; + + using TTNGemmNormalPermuteB = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute + >; + + using TTNGemmNormalPermuteBD = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor + >; + + using TTNGemmNormalPermuteD = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor + >; + + using TTNGemmNormalPermuteAB = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute + >; + + using TTNGemmNormalPermuteABD = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermute0213RowMajor, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor5DPermute02413ColumnMajor + >; + + // + // TTT (Row-major) BMMs + // + + using TTTGemmBatchedPermuteA = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::NoPermute + >; + + using TTTGemmBatchedPermuteAD = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor + >; + + using TTTGemmBatchedPermuteB = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::NoPermute + >; + + using TTTGemmBatchedPermuteBD = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor + >; + + using TTTGemmBatchedPermuteD = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor + >; + + using TTTGemmBatchedPermuteAB = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::NoPermute, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor + >; + + using TTTGemmBatchedPermuteABD = GemmPermute< + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor, + 
cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor, + cutlass::layout::RowMajor, cutlass::layout::Tensor4DPermuteBMM0213RowMajor + >; + + // + // NNN (Col-major) BMMs + // + + using NNNGemmBatchedPermuteA = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute + >; + + using NNNGemmBatchedPermuteAD = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor + >; + + using NNNGemmBatchedPermuteB = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute + >; + + using NNNGemmBatchedPermuteBD = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor + >; + + using NNNGemmBatchedPermuteD = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor + >; + + using NNNGemmBatchedPermuteAB = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::NoPermute + >; + + using NNNGemmBatchedPermuteABD = GemmPermute< + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor, + cutlass::layout::ColumnMajor, cutlass::layout::Tensor4DPermuteBMM0321ColumnMajor + >; + + // + // Profile it + // + + Testbed testbed(options); + + bool result = true; + + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= 
testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + result &= testbed.profile_GEMM_permute(); + + std::cout << "\n" + "====================================================\n" + "Finished (" << (result ? "PASS" : "FAIL") << ")\n" + "====================================================" << std::endl; + + return result ? EXIT_SUCCESS : EXIT_FAILURE; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/39_gemm_permute/layouts.h b/examples/39_gemm_permute/layouts.h new file mode 100644 index 0000000000..3632ec0afb --- /dev/null +++ b/examples/39_gemm_permute/layouts.h @@ -0,0 +1,506 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines additional layout functions used in Permute GEMM example to simplify + computing reference permutations of 4/5D tensors when source data is column-major. 
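+
+           The layouts defined here (TensorCWHN, TensorNHCW, TensorNCWH, TensorCWHDN) mirror
+           cutlass::layout::TensorNHWC / TensorNDHWC but keep a different dimension contiguous
+           in memory; several of them are used by PermuteInfo (permute_info.h) to describe the
+           storage order of host reference tensors for the column-major permute variants.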
+*/ +#pragma once +#include +#include "cutlass/cutlass.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/coord.h" +#include "cutlass/tensor_coord.h" + +namespace cutlass { +namespace layout { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Mapping function for 4-D CWHN tensors. +class TensorCWHN { +public: + /// Logical rank of tensor + static int const kRank = 4; + + /// Rank of stride vector + static int const kStrideRank = 3; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate (n, h, w, c) + using TensorCoord = Tensor4DCoord; + + /// Stride vector + using Stride = Coord; + +private: + // + // Data members + // + + /// Stride data member - [n, hn, whn] + Stride stride_; + +public: + // + // Methods + // + + /// Constructor + CUTLASS_HOST_DEVICE + TensorCWHN(Stride const &stride = Stride(0)): stride_(stride) { } + + /// Constructor + CUTLASS_HOST_DEVICE + TensorCWHN( + typename Stride::Index stride_h, ///< number of elements between adjacent N coordinates + typename Stride::Index stride_w, ///< number of elements between adjacent C coordinates + typename Stride::Index stride_c ///< number of elements between adjacent W coordinates + ): + stride_(make_Coord(stride_h, stride_w, stride_c)) { } + + /// Constructor + // Once convolutions implement 64b stride this ctor can be deleted + CUTLASS_HOST_DEVICE + TensorCWHN(Coord const &stride): + stride_(make_Coord( + static_cast(stride[0]), + static_cast(stride[1]), + static_cast(stride[2])) + ) { } + + /// Helper returns a layout to a tightly packed WCNH tensor. + CUTLASS_HOST_DEVICE + static TensorCWHN packed(TensorCoord const &extent) { + return TensorCWHN( + make_Coord( + extent.n(), + extent.h() * extent.n(), + extent.w() * extent.h() * extent.n() + ) + ); + } + + /// Returns the offset of a coordinate (n, h, w, c) in linear memory. + CUTLASS_HOST_DEVICE + LongIndex operator()(TensorCoord const &coord) const { + return coord.n() + + LongIndex(stride_[0] * coord.h()) + + LongIndex(stride_[1] * coord.w()) + + LongIndex(stride_[2] * coord.c()); + } + + /// Returns the offset of a pitchlinear coordinate in linear memory. + CUTLASS_HOST_DEVICE + LongIndex operator()(PitchLinearCoord coord) const { + return coord.contiguous() + LongIndex(coord.strided() * stride_[2]); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { + return stride_; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride & stride() { + return stride_; + } + + /// Compute the number of contiguous elements needed to store a tensor with the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &extent) const { + // it does not make sense if the extent is larger than stride + // and we could not rely on the capacity calculation in such cases + // we could move this checkers to debug code only + if ((extent.n() > stride_[0]) + || (extent.h() * stride_[0] > stride_[1]) + || (extent.w() * stride_[1] > stride_[2])) { + assert(0); + } + return extent.c() * stride_[2]; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Mapping function for 4-D NHCW tensors. 
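+/// In a packed NHCW tensor the W dimension is contiguous in memory; the stored strides
+/// correspond to the C, H and N coordinates, i.e. [w, c*w, h*c*w] when tightly packed.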
+class TensorNHCW { +public: + /// Logical rank of tensor + static int const kRank = 4; + + /// Rank of stride vector + static int const kStrideRank = 3; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate (n, h, w, c) + using TensorCoord = Tensor4DCoord; + + /// Stride vector + using Stride = Coord; + +private: + // + // Data members + // + + /// Stride data member - [w, cw, hcw] + Stride stride_; + +public: + // + // Methods + // + + /// Constructor + CUTLASS_HOST_DEVICE + TensorNHCW(Stride const &stride = Stride(0)): stride_(stride) { } + + /// Constructor + CUTLASS_HOST_DEVICE + TensorNHCW( + typename Stride::Index stride_c, ///< number of elements between adjacent C coordinates + typename Stride::Index stride_h, ///< number of elements between adjacent H coordinates + typename Stride::Index stride_n ///< number of elements between adjacent N coordinates + ): + stride_(make_Coord(stride_c, stride_h, stride_n)) { } + + /// Constructor + // Once convolutions implement 64b stride this ctor can be deleted + CUTLASS_HOST_DEVICE + TensorNHCW(Coord const &stride): + stride_(make_Coord( + static_cast(stride[0]), + static_cast(stride[1]), + static_cast(stride[2])) + ) { } + + /// Helper returns a layout to a tightly packed WCNH tensor. + CUTLASS_HOST_DEVICE + static TensorNHCW packed(TensorCoord const &extent) { + return TensorNHCW( + make_Coord( + extent.w(), + extent.c() * extent.w(), + extent.h() * extent.c() * extent.w() + ) + ); + } + + /// Returns the offset of a coordinate (n, h, w, c) in linear memory. + CUTLASS_HOST_DEVICE + LongIndex operator()(TensorCoord const &coord) const { + return coord.w() + + LongIndex(stride_[0] * coord.c()) + + LongIndex(stride_[1] * coord.h()) + + LongIndex(stride_[2] * coord.n()); + } + + /// Returns the offset of a pitchlinear coordinate in linear memory. + CUTLASS_HOST_DEVICE + LongIndex operator()(PitchLinearCoord coord) const { + return coord.contiguous() + LongIndex(coord.strided() * stride_[2]); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { + return stride_; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride & stride() { + return stride_; + } + + /// Compute the number of contiguous elements needed to store a tensor with the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &extent) const { + // it does not make sense if the extent is larger than stride + // and we could not rely on the capacity calculation in such cases + // we could move this checkers to debug code only + if ((extent.w() > stride_[0]) + || (extent.c() * stride_[0] > stride_[1]) + || (extent.h() * stride_[1] > stride_[2])) { + assert(0); + } + return extent.n() * stride_[2]; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Mapping function for 4-D NHCW tensors. 
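+/// (This class implements the NCWH ordering: the H dimension is contiguous in memory, and the
+/// stored strides correspond to the W, C and N coordinates, i.e. [h, w*h, c*w*h] when packed.)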
+class TensorNCWH { +public: + /// Logical rank of tensor + static int const kRank = 4; + + /// Rank of stride vector + static int const kStrideRank = 3; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate (n, h, w, c) + using TensorCoord = Tensor4DCoord; + + /// Stride vector + using Stride = Coord; + +private: + // + // Data members + // + + /// Stride data member - [h, wh, cwh] + Stride stride_; + +public: + // + // Methods + // + + /// Constructor + CUTLASS_HOST_DEVICE + TensorNCWH(Stride const &stride = Stride(0)): stride_(stride) { } + + /// Constructor + CUTLASS_HOST_DEVICE + TensorNCWH( + typename Stride::Index stride_w, ///< number of elements between adjacent C coordinates + typename Stride::Index stride_c, ///< number of elements between adjacent H coordinates + typename Stride::Index stride_n ///< number of elements between adjacent N coordinates + ): + stride_(make_Coord(stride_w, stride_c, stride_n)) { } + + /// Constructor + // Once convolutions implement 64b stride this ctor can be deleted + CUTLASS_HOST_DEVICE + TensorNCWH(Coord const &stride): + stride_(make_Coord( + static_cast(stride[0]), + static_cast(stride[1]), + static_cast(stride[2])) + ) { } + + /// Helper returns a layout to a tightly packed WCNH tensor. + CUTLASS_HOST_DEVICE + static TensorNCWH packed(TensorCoord const &extent) { + return TensorNCWH( + make_Coord( + extent.h(), + extent.w() * extent.h(), + extent.c() * extent.w() * extent.h() + ) + ); + } + + /// Returns the offset of a coordinate (n, h, w, c) in linear memory. + CUTLASS_HOST_DEVICE + LongIndex operator()(TensorCoord const &coord) const { + return coord.h() + + LongIndex(stride_[0] * coord.w()) + + LongIndex(stride_[1] * coord.c()) + + LongIndex(stride_[2] * coord.n()); + } + + /// Returns the offset of a pitchlinear coordinate in linear memory. + CUTLASS_HOST_DEVICE + LongIndex operator()(PitchLinearCoord coord) const { + return coord.contiguous() + LongIndex(coord.strided() * stride_[2]); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { + return stride_; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride & stride() { + return stride_; + } + + /// Compute the number of contiguous elements needed to store a tensor with the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &extent) const { + // it does not make sense if the extent is larger than stride + // and we could not rely on the capacity calculation in such cases + // we could move this checkers to debug code only + if ((extent.h() > stride_[0]) + || (extent.w() * stride_[0] > stride_[1]) + || (extent.c() * stride_[1] > stride_[2])) { + assert(0); + } + return extent.n() * stride_[2]; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Mapping function for 5-D CWHDN tensors. 
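+/// In a packed CWHDN tensor the N dimension is contiguous in memory; the stored strides
+/// correspond to the D, H, W and C coordinates, i.e. [n, d*n, h*d*n, w*h*d*n] when packed.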
+class TensorCWHDN { +public: + /// Logical rank of tensor + static int const kRank = 5; + + /// Rank of stride vector + static int const kStrideRank = 4; + + /// Index type used for coordinates + using Index = int32_t; + + /// Long index type used for offsets + using LongIndex = int64_t; + + /// Logical coordinate (n, d, h, w, c) + using TensorCoord = Tensor5DCoord; + + /// Stride vector + using Stride = Coord; + +private: + // + // Data members + // + + /// Stride data member - [n, dn, hdn, whdn] + Stride stride_; + +public: + // + // Methods + // + + /// Constructor + CUTLASS_HOST_DEVICE + TensorCWHDN(Stride const &stride = Stride(0)): stride_(stride) { } + + /// Constructor + CUTLASS_HOST_DEVICE + TensorCWHDN( + typename Stride::Index n, + typename Stride::Index dn, + typename Stride::Index hdn, + typename Stride::Index whdn): + stride_(make_Coord(n, dn, hdn, whdn)) { } + + /// Constructor + // Once convolutions implement 64b stride this ctor can be deleted + CUTLASS_HOST_DEVICE + TensorCWHDN(Coord const &stride): + stride_(make_Coord( + static_cast(stride[0]), + static_cast(stride[1]), + static_cast(stride[2]), + static_cast(stride[3])) + ) { } + + /// Helper returns a layout to a tightly packed CWHDN tensor. + CUTLASS_HOST_DEVICE + static TensorCWHDN packed(TensorCoord const &extent) { + return TensorCWHDN( + make_Coord( + extent.n(), + extent.d() * extent.n(), + extent.h() * extent.d() * extent.n(), + extent.w() * extent.h() * extent.d() * extent.n() + ) + ); + } + + /// Returns the offset of a coordinate (n, d, h, w, c) in linear memory. + CUTLASS_HOST_DEVICE + LongIndex operator()(TensorCoord const &coord) const { + return coord.n() + + LongIndex(stride_[0] * coord.d()) + + LongIndex(stride_[1] * coord.h()) + + LongIndex(stride_[2] * coord.w()) + + LongIndex(stride_[3] * coord.c()); + } + + /// Returns the offset of a pitchlinear coordinate in linear memory. + CUTLASS_HOST_DEVICE + LongIndex operator()(PitchLinearCoord coord) const { + return coord.contiguous() + LongIndex(coord.strided() * stride_[3]); + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride stride() const { + return stride_; + } + + /// Returns the stride of the layout + CUTLASS_HOST_DEVICE + Stride & stride() { + return stride_; + } + + /// Compute the number of contiguous elements needed to store a tensor with the given size + CUTLASS_HOST_DEVICE + LongIndex capacity(TensorCoord const &extent) const { + // it does not make sense if the extent is larger than stride + // and we could not rely on the capacity calculation in such cases + // we could move this checkers to debug code only + if ((extent.n() > stride_[0]) + || (extent.d() * stride_[0] > stride_[1]) + || (extent.h() * stride_[1] > stride_[2]) + || (extent.w() * stride_[2] > stride_[3])) { + assert(0); + } + return extent.c() * stride_[3]; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace layout +} // namespace cutlass diff --git a/examples/39_gemm_permute/permute_info.h b/examples/39_gemm_permute/permute_info.h new file mode 100644 index 0000000000..57672e7c49 --- /dev/null +++ b/examples/39_gemm_permute/permute_info.h @@ -0,0 +1,344 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Contains additional metadata about layout permute functions used in the example. +*/ + +#include "cutlass/tensor_coord.h" +#include "cutlass/layout/permute.h" + +/// Additional permutation metadata to facilitate testing/printing +template +struct PermuteInfo; + +/// Specialization for default case (no permute). Other specializations must follow this template. +template<> +struct PermuteInfo { + + /// Whether this is a BMM or GEMM permutation (NoPermute can actually be either) + static bool constexpr kBatched = false; + + /// Minimal divisor for row extent + static int constexpr kRowFactor = 1; + + /// Minimum divisor for column extent + static int constexpr kColumnFactor = 1; + + /// Minimum divisor for batch size dimension + static int constexpr kBatchFactor = 1; + + /// Tensor layout used in permutation operation + using Layout = cutlass::layout::PackedVectorLayout; + + static std::string name() { + return "NoPermute"; + } + + /// User-friendly description of the permute operation + static std::string desc() { + return "no permutation"; + } + + /// Infer original higher-rank tensor shape from GEMM/BMM matrix extents. + /// For direct (output) permutations, must be a simple reshape of extent. + /// For inverse (input) permutations, must return shape *before* permute operation. + /// In case of NoPermute, simply use a linear (rank 1) view of the memory + static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + return Layout::TensorCoord(extent.row() * extent.column() * batch_count); + } + + /// Compute the permuted higher-rank tensor shape from the original shape. 
+ static Layout::TensorCoord permute(Layout::TensorCoord const &s) { + return s; + } +}; + +template +struct PermuteInfo> { + + static bool constexpr kBatched = true; + static int constexpr kRowFactor = 1; + static int constexpr kColumnFactor = 1; + static int constexpr kBatchFactor = D1; + + using Layout = cutlass::layout::TensorNHWC; + + static std::string name() { + return "Tensor4DPermuteBMM0213<" + std::to_string(D1) + ">"; + } + + static std::string desc() { + return "batched GEMM permutation [0, 2, 1, 3]"; + } + + static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int D0 = batch_count / D1; + int D2 = extent.row(); + int D3 = extent.column(); + return {D0, D1, D2, D3}; + } + + static Layout::TensorCoord permute(Layout::TensorCoord const &s) { + return {s[0], s[2], s[1], s[3]}; + } +}; + +template +struct PermuteInfo> +: public PermuteInfo> { + + static bool constexpr kBatched = true; + static int constexpr kRowFactor = 1; + static int constexpr kColumnFactor = D1; + static int constexpr kBatchFactor = 1; + + using Base = PermuteInfo>; + using Layout = typename Base::Layout; + + static typename Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int D0 = batch_count; + int D2 = extent.row(); + int D3 = extent.column() / D1; + return {D0, D1, D2, D3}; + } +}; + +template +struct PermuteInfo> { + + static bool constexpr kBatched = true; + static int constexpr kRowFactor = 1; + static int constexpr kColumnFactor = 1; + static int constexpr kBatchFactor = D1; + + using Layout = cutlass::layout::TensorNHCW; + + static std::string name() { + return "Tensor4DPermuteBMM0321<" + std::to_string(D1) + ">"; + } + + static std::string desc() { + return "batched GEMM permutation [0, 3, 2, 1]"; + } + + static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int D0 = batch_count / D1; + int D2 = extent.row(); + int D3 = extent.column(); + return {D0, D1, D2, D3}; + } + + static Layout::TensorCoord permute(Layout::TensorCoord const &s) { + return {s[0], s[3], s[2], s[1]}; + } +}; + +template +struct PermuteInfo> +: public PermuteInfo> { + + static bool constexpr kBatched = true; + static int constexpr kRowFactor = D1; + static int constexpr kColumnFactor = 1; + static int constexpr kBatchFactor = 1; + + using Base = PermuteInfo>; + using Layout = typename Base::Layout; + + static typename Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int D0 = batch_count; + int D2 = extent.row() / D1; + int D3 = extent.column(); + return {D0, D1, D2, D3}; + } +}; + +template +struct PermuteInfo> { + + static bool constexpr kBatched = false; + static int constexpr kRowFactor = D1; + static int constexpr kColumnFactor = D2; + static int constexpr kBatchFactor = 1; + + using Layout = cutlass::layout::TensorNHWC; + + static std::string name() { + return "Tensor4DPermute0213<" + std::to_string(D1) + "," + std::to_string(D2) + ">"; + } + + static std::string desc() { + return "normal GEMM permutation [0, 2, 1, 3]"; + } + + static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int D0 = extent.row() / D1; + int D3 = extent.column() / D2; + return {D0, D1, D2, D3}; + } + + static Layout::TensorCoord permute(Layout::TensorCoord const &s) { + return {s[0], s[2], s[1], s[3]}; + } +}; + +template +struct PermuteInfo> +: public PermuteInfo> { + + static bool constexpr kBatched = false; + static int constexpr kRowFactor = D2; + static int constexpr 
kColumnFactor = D1; + static int constexpr kBatchFactor = 1; + + using Base = PermuteInfo>; + using Layout = typename Base::Layout; + + static typename Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int D0 = extent.row() / D2; + int D3 = extent.column() / D1; + return {D0, D1, D2, D3}; + } +}; + +template +struct PermuteInfo> +: public PermuteInfo> { + using Layout = cutlass::layout::TensorCWHN; +}; + +template +struct PermuteInfo> +: public PermuteInfo> { + using Layout = cutlass::layout::TensorCWHN; +}; + +template +struct PermuteInfo> { + + static bool constexpr kBatched = false; + static int constexpr kRowFactor = T1; + static int constexpr kColumnFactor = T2 * T3; + static int constexpr kBatchFactor = 1; + + using Layout = cutlass::layout::TensorNDHWC; + + static std::string name() { + return "Tensor5DPermute20314<" + std::to_string(T1) + "," + std::to_string(T2) + "," + std::to_string(T3) + ">"; + } + + static std::string desc() { + return "normal GEMM permutation [2, 0, 3, 1, 4]"; + } + + static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) + { + int const T0 = extent.row() / T1; + int const T4 = extent.column() / (T2 * T3); + return {T0, T1, T2, T3, T4}; + } + + static Layout::TensorCoord permute(Layout::TensorCoord const &s) + { + return {s[2], s[0], s[3], s[1], s[4]}; + } +}; + +template +struct PermuteInfo> +: public PermuteInfo> { + + static bool constexpr kBatched = false; + static int constexpr kRowFactor = T2; + static int constexpr kColumnFactor = T1 * T3; + static int constexpr kBatchFactor = 1; + + using Base = PermuteInfo>; + using Layout = typename Base::Layout; + + static typename Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int const T0 = extent.row() / T2; + int const T4 = extent.column() / (T1 * T3); + return {T0, T1, T2, T3, T4}; + } +}; + +template +struct PermuteInfo> { + + static bool constexpr kBatched = false; + static int constexpr kRowFactor = T1; + static int constexpr kColumnFactor = T2 * T3; + static int constexpr kBatchFactor = 1; + + using Layout = cutlass::layout::TensorCWHDN; + + static std::string name() { + return "Tensor5DPermute02413<" + std::to_string(T1) + "," + std::to_string(T2) + "," + std::to_string(T3) + ">"; + } + + static std::string desc() { + return "normal GEMM permutation [0, 2, 4, 1, 3]"; + } + + using Coord = cutlass::Tensor5DCoord; + + static Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) + { + int const T0 = extent.row() / T1; + int const T4 = extent.column() / (T2 * T3); + return {T0, T1, T2, T3, T4}; + } + + static Layout::TensorCoord permute(Layout::TensorCoord const &s) + { + return {s[0], s[2], s[4], s[1], s[3]}; + } +}; + +template +struct PermuteInfo> +: public PermuteInfo> { + + static bool constexpr kBatched = false; + static int constexpr kRowFactor = T2; + static int constexpr kColumnFactor = T1 * T3; + static int constexpr kBatchFactor = 1; + + using Base = PermuteInfo>; + using Layout = typename Base::Layout; + + static typename Layout::TensorCoord original_shape(cutlass::MatrixCoord extent, int batch_count) { + int const T0 = extent.row() / T2; + int const T4 = extent.column() / (T1 * T3); + return {T0, T1, T2, T3, T4}; + } +}; diff --git a/examples/40_cutlass_py/README.md b/examples/40_cutlass_py/README.md new file mode 100644 index 0000000000..c670e34072 --- /dev/null +++ b/examples/40_cutlass_py/README.md @@ -0,0 +1,4 @@ +# PyCUTLASS Examples + +This directory contains 
deprecated examples for PyCUTLASS, a precursor to the CUTLASS Python interface. +For examples of using CUTLASS's actively-maintained Pythonic interface, see the [examples/python](/examples/python) directory. diff --git a/examples/40_cutlass_py/conv2d.py b/examples/40_cutlass_py/conv2d.py new file mode 100644 index 0000000000..71e94259ff --- /dev/null +++ b/examples/40_cutlass_py/conv2d.py @@ -0,0 +1,177 @@ +################################################################################ +# +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ +""" +Basic example of using the CUTLASS Python interface to run a 2d convolution +""" + +import sys +print("This example is deprecated. Please see examples/python for examples of using " + "the CUTLASS Python interface.") +sys.exit(0) + +import argparse +import numpy as np +import torch + +import cutlass_bindings +import cutlass.backend as pycutlass +from cutlass.backend import * +from cutlass.backend.utils.reference_model import Conv2dReferenceModule +from cutlass.backend.utils.device import device_cc + + +parser = argparse.ArgumentParser( + description=("Launch a 2d convolution kernel from Python. 
" + "See https://docs.nvidia.com/deeplearning/performance/dl-performance-convolutional/index.html#convo-intro for notation.")) +parser.add_argument("--n", default=1, type=int, help="N dimension of the convolution") +parser.add_argument("--c", default=64, type=int, help="C dimension of the convolution") +parser.add_argument("--h", default=32, type=int, help="H dimension of the convolution") +parser.add_argument("--w", default=32, type=int, help="W dimension of the convolution") +parser.add_argument("--k", default=32, type=int, help="N dimension of the convolution") +parser.add_argument("--r", default=3, type=int, help="R dimension of the convolution") +parser.add_argument("--s", default=3, type=int, help="S dimension of the convolution") +parser.add_argument('--print_cuda', action="store_true", help="Print the underlying CUDA kernel") + +try: + args = parser.parse_args() +except: + sys.exit(0) + +# Check that the device is of a sufficient compute capability +cc = device_cc() +assert cc >= 70, "The CUTLASS Python Conv2d example requires compute capability greater than or equal to 70." + +alignment = 1 + +np.random.seed(0) + +# Allocate a pool of device memory to be used by the kernel +pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) + +# Set the compiler to use to NVCC +pycutlass.compiler.nvcc() + +# Set up A, B, C and accumulator +A = TensorDescription(cutlass_bindings.float16, cutlass_bindings.TensorNHWC, alignment) +B = TensorDescription(cutlass_bindings.float16, cutlass_bindings.TensorNHWC, alignment) +C = TensorDescription(cutlass_bindings.float32, cutlass_bindings.TensorNHWC, alignment) +element_acc = cutlass_bindings.float32 +element_epilogue = cutlass_bindings.float32 + +# Select instruction shape based on the Tensor Core instructions supported +# by the device on which we are running +if cc == 70: + instruction_shape = [8, 8, 4] +elif cc == 75: + instruction_shape = [16, 8, 8] +else: + # Use CUTLASS kernels for CC 80 by default (e.g., for cases in which SM86 is used) + cc = 80 + instruction_shape = [16, 8, 16] + +math_inst = MathInstruction( + instruction_shape, + A.element, B.element, element_acc, + cutlass_bindings.OpClass.TensorOp, + MathOperation.multiply_add +) + +tile_description = TileDescription( + [128, 128, 32], # Threadblock shape + 2, # Number of stages + [2, 2, 1], # Number of warps within each dimension of the threadblock shape + math_inst +) + +epilogue_functor = pycutlass.LinearCombination(C.element, C.alignment, element_acc, element_epilogue) + +operation = Conv2dOperation( + conv_kind=cutlass_bindings.conv.Operator.fprop, + iterator_algorithm=cutlass_bindings.conv.IteratorAlgorithm.optimized, + arch=cc, tile_description=tile_description, + A=A, B=B, C=C, stride_support=StrideSupport.Unity, + epilogue_functor=epilogue_functor +) + +if args.print_cuda: + print(operation.rt_module.emit()) + +operations = [operation, ] + +# Compile the operation +pycutlass.compiler.add_module(operations) + +# Randomly initialize tensors + +problem_size = cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(args.n, args.h, args.c, args.w), + cutlass_bindings.Tensor4DCoord(args.k, args.r, args.s, args.c), + cutlass_bindings.Tensor4DCoord(0, 0, 0, 0), # Padding + cutlass_bindings.MatrixCoord(1, 1), # Strides + cutlass_bindings.MatrixCoord(1, 1), # Dilation + cutlass_bindings.conv.Mode.cross_correlation, + 1, # Split k slices + 1 # Groups +) + +tensor_A_size = cutlass_bindings.conv.implicit_gemm_tensor_a_size(operation.conv_kind, problem_size) 
+tensor_B_size = cutlass_bindings.conv.implicit_gemm_tensor_b_size(operation.conv_kind, problem_size) +tensor_C_size = cutlass_bindings.conv.implicit_gemm_tensor_c_size(operation.conv_kind, problem_size) + +tensor_A = torch.ceil(torch.empty(size=(tensor_A_size,), dtype=torch.float16, device="cuda").uniform_(-8.5, 7.5)) +tensor_B = torch.ceil(torch.empty(size=(tensor_B_size,), dtype=torch.float16, device="cuda").uniform_(-8.5, 7.5)) +tensor_C = torch.ceil(torch.empty(size=(tensor_C_size,), dtype=torch.float32, device="cuda").uniform_(-8.5, 7.5)) +tensor_D = torch.ones(size=(tensor_C_size,), dtype=torch.float32, device="cuda") + +alpha = 1. +beta = 0. + +arguments = Conv2dArguments( + operation=operation, problem_size=problem_size, + A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, + output_op=operation.epilogue_type(alpha, beta) +) + +# Run the operation +operation.run(arguments) +arguments.sync() + +# Run the host reference module and compare to the CUTLASS result +reference = Conv2dReferenceModule(A, B, C, operation.conv_kind) +tensor_D_ref = reference.run(tensor_A, tensor_B, tensor_C, problem_size, alpha, beta) + +try: + assert torch.equal(tensor_D, tensor_D_ref) +except: + assert torch.allclose(tensor_D, tensor_D_ref, rtol=1e-2) + +print("Passed.") diff --git a/examples/40_cutlass_py/customizable/README.md b/examples/40_cutlass_py/customizable/README.md new file mode 100644 index 0000000000..e8aeee9e71 --- /dev/null +++ b/examples/40_cutlass_py/customizable/README.md @@ -0,0 +1,167 @@ +# Customizable Python Interface Examples +This directory contains examples of using the CUTLASS Python interface with a variety of configurations for kernels. + +For all the tests, add `--print_cuda` to print the underlying CUDA kernel. Use `-h` or `--help` to display the help message. + +## GEMM Examples +The GEMM examples use numpy to create input tensors and verify the results. 
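+
+As a rough sketch (not the exact code in `gemm.py`), the verification amounts to computing a
+host reference with numpy and comparing it against the device result. Shapes and scalars below
+are illustrative only, taken from the `-p 512 256 128 -alpha 1.0 -beta 0.5` settings used in the
+commands that follow:
+```python
+import numpy as np
+
+M, N, K = 512, 256, 128        # -p 512 256 128
+alpha, beta = 1.0, 0.5         # -alpha 1.0 -beta 0.5
+
+A = np.random.rand(M, K)
+B = np.random.rand(K, N)
+C = np.random.rand(M, N)
+
+D_ref = alpha * (A @ B) + beta * C   # host reference
+# D is the output produced by the CUTLASS kernel; the examples then check, roughly:
+# assert np.allclose(D, D_ref)
+```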
+### GEMM F64 Example +Example 1: SM80_Device_Gemm_f64t_f64n_f64n_tensor_op_f64_32x32x16_16x16x16 +```python +python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 +``` +Example 2: SM80_Device_Gemm_f64n_f64t_f64n_tensor_op_f64_64x64x16_32x32x16, split_k(2)_serial +```python +python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2 +``` + +### GEMM F32 Example +Example 1: SM80_Device_Gemm_f32n_f32t_f32n_tensor_op_bf16_f32_128x128x32_64x64x32 +```python +python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 +``` +Example 2: SM80_Device_Gemm_f32t_f32t_f32n_tensor_op_f32_128x128x32_64x64x32, split_k(2)_parallel +```python +python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 2 +``` +Example 3: SM80_Device_Gemm_f32t_f32t_f32n_tensor_op_fast_accurate_f32_64x64x32_32x32x32, split_k(4)_serial +```python +python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_f32 -op TensorOp -b 64 64 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 4 +``` + +### GEMM F16 Example +Example 1: SM80_Device_Gemm_f32t_f32n_f32t_tensor_op_bf16_f32_128x128x32_64x64x32 +```python +python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 +``` +Example 2: SM80_Device_Gemm_f16t_f16t_f16n_tensor_op_f32_128x128x64_64x64x64, split_k(2)_serial +```python +python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2 +``` +Example 3: SM80_Device_Gemm_f16t_f16t_f32n_tensor_op_f32_256x128x64_64x64x64, split_k(3)_serial +```python +python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 256 128 64 -s 3 -w 4 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 3 +``` + +### GEMM BF16 Example +Example 1: Device_Gemm_bf16t_bf16t_f32n_tensor_op_f32_64x128x64_32x64x64, split_k(5)_parallel 
+```python +python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm GemmSplitKParallel -k 5 +``` + +### GEMM Int8 Example +Example 1: SM80_Device_Gemm_s8n_s8t_s8n_tensor_op_s32_256x128x128_64x64x128 +```python +python gemm.py -i 16 8 32 -ta int8 -tb int8 -tc int8 -tacc int32 -m multiply_add -op TensorOp -b 128 128 128 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 16 -lb ColumnMajor -ab 16 -lc RowMajor -ac 16 -te float32 -ep FastLinearCombinationClamp -sw IdentitySwizzle2 -p 512 512 512 -alpha 1.0 -beta 0.0 -gm Gemm -k 1 +``` + +### Batched & Array GEMM +Example 1: Batched GEMM +```python +python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add_fast_bf16 -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la RowMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw BatchedIdentitySwizzle -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Batched -k 1 -batch 3 +``` +Example 2: Array GEMM +```python +python gemm.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb RowMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Array -k 1 -batch 2 +``` +*** +## GEMM Grouped Examples +The GEMM Grouped examples use numpy to create input tensors and verify the results. + +Example 1: SM80_Device_GemmGrouped_f16t_f16t_f32t_tensor_op_f32_128x128x32_64x64x32, device schedule +```python +python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 0.0 -pm Device +``` +Example 2: SM80_Device_GemmGrouped_f64n_f64n_f64t_tensor_op_f64_64x64x16_32x32x16, host schedule +```python +python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 1.0 -beta 1.0 -pm Host +``` +Example 3: SM80_Device_GemmGrouped_f32n_f32n_f32n_simt_f32_128x64x8_64x32x1, device schedule +```python +python gemm_grouped.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 64 8 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device +``` +Example 4: SM80_Device_GemmGrouped_f16t_f16t_f32t_tensor_op_f32_128x128x32_64x64x32, device schedule +```python +python gemm_grouped.py -i 16 8 16 -ta float16 -tb float16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc ColumnMajor -ac 4 -te float32 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 2.0 -beta 1.0 -pm Device +``` +*** +## Conv2d Example +The Conv2d examples use pytorch to create input tensors and verify the results. Pytorch can be installed following the [official website](https://pytorch.org/get-started/locally/). 
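+For most environments, `pip install torch` with a CUDA-enabled build is sufficient; see the linked page for platform-specific commands.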
+### Conv2d F32 Fprop +Example 1: SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32 +```python +python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 13 17 8 -krsc 24 3 3 8 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0 +``` +Example 2: SM80_Device_Conv2d_Fprop_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align2 +```python +python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 1.0 -beta 1.0 +``` +Example 3: SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32 +```python +python conv2d.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 128 8 -s 4 -w 4 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle4 -co fprop -st Strided -ia analytic -sm Parallel -k 3 -nhwc 1 71 80 32 -krsc 64 5 5 32 -pad 2 2 2 2 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 1.0 +``` +### Conv2d F32 Wgrad +Example 1: Device_Conv2d_Wgrad_Optimized_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32_align1 +```python +python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 1 -lb TensorNHWC -ab 1 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co wgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 8 8 1 -krsc 1 3 3 1 -pad 1 1 1 1 -stride 1 1 -dilation 1 1 -alpha 1.0 -beta 0.0 +``` +Example 2: Device_Conv2d_Wgrad_Analytic_ImplicitGemm_f32nhwc_f32nhwc_f32nhwc_simt_f32 +```python +python conv2d.py -i 1 1 1 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op Simt -b 128 128 8 -s 4 -w 2 4 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 1 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co wgrad -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0 +``` +### Conv2d F32 Dgrad +Example 1: Device_Conv2d_Dgrad_Analytic_ImplicitGemm_tf32nhwc_tf32nhwc_f32nhwc_tensor_op_f32 +```python +python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0 +``` + +### Conv2d F16 Fprop +Example 1: SM80_Device_Conv2d_Fprop_Analytic_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32 +```python +python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 8 -lb TensorNHWC -ab 8 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw 
IdentitySwizzle1 -co fprop -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 27 27 256 -krsc 512 3 3 256 -pad 1 1 1 1 -stride 2 1 -dilation 1 1 -alpha 1.0 -beta 0.0 +``` +Example 2: SM80_Device_Conv2d_Fprop_Few_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_2 +```python +python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0 +``` +Example 3: SM80_Device_Conv2d_Fprop_Fixed_Channels_ImplicitGemm_f16nhwc_f16nhwc_f16nhwc_tensor_op_f32_channels_8 +```python +python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 8 -lb TensorNHWC -ab 8 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia fixed_channels -sm Serial -k 1 -nhwc 1 8 8 8 -krsc 16 3 3 8 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0 +``` +Example 4: SM80_Device_Conv2d_Strided_Dgrad_Optimized_ImplicitGemm_f16nhwc_f16nhwc_f32nhwc_tensor_op_f32_128x128_32x3_64x64x32_align4 +```python +python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 4 -lb TensorNHWC -ab 4 -lc TensorNHWC -ac 4 -te float32 -ep LinearCombination -sw StridedDgradIdentitySwizzle1 -co dgrad -st Strided -ia optimized -sm Serial -k 1 -nhwc 1 56 56 12 -krsc 8 1 1 12 -pad 0 0 0 0 -stride 2 2 -dilation 1 1 -alpha 1.0 -beta 0.0 +``` + +## Epilogue +### Bias +To replace C with a bias vector, add `-bias` flag. 
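+For example, appending `-bias` to the F64 GEMM Example 1 above makes the kernel treat C as a bias vector instead of a full matrix (length N for the `RowMajor` C layout used here; see how `tensor_c_size` is computed in `gemm.py` below). This command is an illustrative variant of that example, not a separately listed configuration:
+```python
+python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 -bias
+```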
+### Activation function +Example 1: ReLU +```python +python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 32 32 16 -s 4 -w 2 2 1 -cc 80 -la ColumnMajor -aa 1 -lb RowMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 1 -bias -activ relu +``` +Example 2: leaky ReLU +```python +python gemm.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb ColumnMajor -ab 1 -lc RowMajor -ac 1 -te float64 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 1.0 -beta 0.5 -gm Gemm -k 2 -bias -activ leaky_relu -activ_arg 0.2 +``` +Example 3: tanh (alpha=0 to avoid saturation) +```python +python gemm.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 32 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 4 -lb ColumnMajor -ab 4 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -p 512 256 128 -alpha 0.0 -beta 0.5 -gm GemmSplitKParallel -k 2 -bias -activ tanh +``` +Example 4: sigmoid +```python +python gemm_grouped.py -i 8 8 4 -ta float64 -tb float64 -tc float64 -tacc float64 -m multiply_add -op TensorOp -b 64 64 16 -s 4 -w 2 2 1 -cc 80 -la RowMajor -aa 1 -lb RowMajor -ab 1 -lc ColumnMajor -ac 1 -te float64 -ep LinearCombination -p ./grouped_gemm_problem_size.csv -alpha 0.0 -beta 0.5 -pm Host -bias -activ sigmoid +``` +Example 5: SiLU +```python +python conv2d.py -i 16 8 8 -ta float32 -tb float32 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 128 128 16 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 2 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -co fprop -st Strided -ia optimized -sm Serial -k 2 -nhwc 1 4 4 12 -krsc 8 3 3 12 -pad 0 0 0 0 -stride 3 3 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ silu +``` +Example 6: HardSwish +```python +python conv2d.py -i 16 8 16 -ta float16 -tb float16 -tc float16 -tacc float32 -m multiply_add -op TensorOp -b 128 128 64 -s 3 -w 2 2 1 -cc 80 -la TensorNHWC -aa 2 -lb TensorNHWC -ab 2 -lc TensorNHWC -ac 8 -te float32 -ep LinearCombination -sw IdentitySwizzle1 -co fprop -st Strided -ia few_channels -sm Serial -k 1 -nhwc 1 16 16 2 -krsc 16 3 3 2 -pad 1 1 1 1 -stride 2 2 -dilation 1 1 -alpha 0.0 -beta 0.5 -bias -activ hardswish +``` +Example 7: GELU +```python +python gemm.py -i 16 8 16 -ta bfloat16 -tb bfloat16 -tc float32 -tacc float32 -m multiply_add -op TensorOp -b 64 128 64 -s 3 -w 2 2 1 -cc 80 -la ColumnMajor -aa 8 -lb ColumnMajor -ab 8 -lc RowMajor -ac 4 -te float32 -ep LinearCombination -sw IdentitySwizzle2 -p 512 256 128 -alpha 0.0 -beta 0.5 -gm GemmSplitKParallel -k 5 -bias -activ gelu +``` diff --git a/examples/40_cutlass_py/customizable/conv2d.py b/examples/40_cutlass_py/customizable/conv2d.py new file mode 100644 index 0000000000..c6cbf87a8d --- /dev/null +++ b/examples/40_cutlass_py/customizable/conv2d.py @@ -0,0 +1,331 @@ +################################################################################ +# +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +import sys +print("This example is deprecated. Please see examples/python for examples of using " + "the CUTLASS Python interface.") +sys.exit(0) + +import numpy as np +import cutlass.backend as pycutlass +from cutlass.backend import * +from cutlass.backend.utils.device import device_cc +from cutlass.backend.conv2d_operation import * +from cutlass.backend.utils.reference_model import Conv2dReferenceModule +import torch.nn.functional as F + +import argparse + +# parse the arguments +parser = argparse.ArgumentParser(description="Launch CUTLASS convolution 2d kernels from Python") + +# Operation description +# math instruction description +parser.add_argument("-i", "--instruction_shape", + default=[1, 1, 1], nargs=3, type=int, + help="This option describes the size of MMA op") +parser.add_argument("-ta", "--element_a", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor A') +parser.add_argument("-tb", "--element_b", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor B') +parser.add_argument("-tc", "--element_c", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor C and output tensor D') +parser.add_argument("-tacc", "--element_acc", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of accumulator') +parser.add_argument('-m', "--math", default="multiply_add", + type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction") +parser.add_argument('-op', "--opcode", default="Simt", type=str, + choices=["Simt", 'TensorOp'], + help='This option describes whether you want to use tensor \ + cores (TensorOp) or regular SIMT cores (Simt) on GPU SM') +# tile description +parser.add_argument("-b", "--threadblock_shape", + default=[128, 128, 8], nargs=3, type=int, + help="This option describes the tile 
size a thread block with compute") +parser.add_argument("-s", "--stages", default=4, + type=int, help="Number of pipelines you want to use") +parser.add_argument("-w", "--warp_count", default=[ + 4, 2, 1], nargs=3, type=int, + help="This option describes the number of warps along M, N, and K of the threadblock") +parser.add_argument("-cc", "--compute_capability", default=80, + type=int, help="This option describes CUDA SM architecture number") +# A +parser.add_argument('-la', "--layout_a", default="TensorNHWC", type=str, choices=[ + "TensorNHWC", "TensorNC32HW32"], + help="Memory layout of input tensor A") +parser.add_argument('-aa', '--alignment_a', default=1, + type=int, help="Memory alignement of input tensor A") +# B +parser.add_argument('-lb', "--layout_b", default="TensorNHWC", type=str, choices=[ + "TensorNHWC", "TensorC32RSK32"], + help="Memory layout of input tensor B") +parser.add_argument('-ab', '--alignment_b', default=1, + type=int, help="Memory alignment of input tensor B") +# C +parser.add_argument('-lc', "--layout_c", default="TensorNHWC", type=str, choices=[ + "TensorNHWC", "TensorNC32HW32"], + help="Memory layout of input tensor C and output tensor D") +parser.add_argument('-ac', '--alignment_c', default=1, + type=int, help="Memory alignment of input tensor C and output tensor D") +# epilogue +parser.add_argument("-te", "--element_epilogue", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16'], + help='Data type of computation in the epilogue') +parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination", + type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'], + help="This option describes the epilogue part of the kernel") +# swizzling +parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[ + "IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", + "HorizontalSwizzle", "StridedDgradIdentitySwizzle1", "StridedDgradIdentitySwizzle4", + "StridedDgradHorizontalSwizzle"], + help="This option describes how thread blocks are scheduled on GPU") +# conv related +parser.add_argument("-co", "--conv_kind", default="fprop", type=str, choices=['fprop', 'dgrad', 'wgrad'], + help="The type of convolution: forward propagation (fprop), \ + gradient of activation (dgrad), gradient of weight (wgrad)") +parser.add_argument("-st", "--stride_support", default="Strided", type=str, choices=["Strided", "Unity"], + ) +parser.add_argument("-ia", "--iterator_algorithm", default="analytic", type=str, + choices=["analytic", "optimized", "fixed_channels", "few_channels"], + help="This option describes iterator algorithm") + +# arguments +parser.add_argument("-sm", "--split_k_mode", default="Serial", type=str, choices=["Serial", "Parallel"], + help="Split K Mode. Serial is used for non-splitK or serial-splitK.\ + Parallel is used for parallel splitK.") +parser.add_argument('-k', '--split_k_slices', default=1, + type=int, help="Number of split-k partitions. 
(default 1)") +parser.add_argument("-nhwc", "--nhwc", nargs=4, type=int, help="input size (NHWC)") +parser.add_argument("-krsc", "--krsc", nargs=4, type=int, help="filter size (KRSC)") +parser.add_argument("-pad", "--pad", nargs=4, type=int, help="padding (pad_h, _, pad_w, _)") +parser.add_argument("-stride", "--stride", nargs=2, type=int, help="stride (stride_h, stride_w)") +parser.add_argument("-dilation", "--dilation", nargs=2, type=int, help="dilation (dilation_h, dilation_w)") +parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha") +parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta") +parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector") +# Activation function +parser.add_argument("-activ", "--activation_function", default="identity", + choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function") +parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float, + help="addition arguments for activation") + + +parser.add_argument('--print_cuda', action="store_true", + help="print the underlying CUDA kernel") + +try: + args = parser.parse_args() +except: + sys.exit(0) + +cc = device_cc() +if args.compute_capability != cc: + raise Exception(("Parameter --compute-capability of {} " + "does not match that of the device of {}.").format(args.compute_capability, cc)) + +pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) + +np.random.seed(0) + +element_a = getattr(cutlass_bindings, args.element_a) +element_b = getattr(cutlass_bindings, args.element_b) +element_c = getattr(cutlass_bindings, args.element_c) +element_acc = getattr(cutlass_bindings, args.element_acc) +math_operation = getattr(MathOperation, args.math) +opclass = getattr(cutlass_bindings.OpClass, args.opcode) + +math_inst = MathInstruction( + args.instruction_shape, element_a, element_b, + element_acc, opclass, math_operation +) + +tile_description = TileDescription( + args.threadblock_shape, args.stages, args.warp_count, + math_inst +) + +layout_a = getattr(cutlass_bindings, args.layout_a) +layout_b = getattr(cutlass_bindings, args.layout_b) +layout_c = getattr(cutlass_bindings, args.layout_c) + +A = TensorDescription( + element_a, layout_a, args.alignment_a +) + +B = TensorDescription( + element_b, layout_b, args.alignment_b +) + +C = TensorDescription( + element_c, layout_c, args.alignment_c +) + +element_epilogue = getattr(cutlass_bindings, args.element_epilogue) +if (args.activation_function == "identity" + or (args.split_k_mode == "Parallel" and args.split_k_slices > 1)): + # + epilogue_functor = getattr(pycutlass, args.epilogue_functor)( + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) +else: + epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")( + getattr(pycutlass, args.activation_function)(element_epilogue), + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + +iterator_algorithm = getattr(cutlass_bindings.conv.IteratorAlgorithm, args.iterator_algorithm) +swizzling_functor = getattr(cutlass_bindings, args.swizzling_functor) +stride_support = getattr(StrideSupport, args.stride_support) +conv_kind = getattr(cutlass_bindings.conv.Operator, args.conv_kind) + +operation = Conv2dOperation( + conv_kind=conv_kind, iterator_algorithm=iterator_algorithm, + arch=args.compute_capability, tile_description=tile_description, + A=A, B=B, C=C, stride_support=stride_support, + 
epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor +) + +if args.print_cuda: + print(operation.rt_module.emit()) + +operations = [operation,] + +if args.split_k_mode == "Parallel" and args.split_k_slices > 1: + if (args.activation_function == "identity"): + epilogue_functor_reduction = getattr(pycutlass, args.epilogue_functor)( + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + else: + epilogue_functor_reduction = getattr(pycutlass, "LinearCombinationGeneric")( + getattr(pycutlass, args.activation_function)(element_epilogue), + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + reduction_operation = ReductionOperation( + shape=cutlass_bindings.MatrixCoord(4, 32 * C.alignment), + C=C, element_accumulator=element_acc, + element_compute=element_epilogue, + epilogue_functor=epilogue_functor_reduction, + count=C.alignment + ) + operations.append(reduction_operation) + +pycutlass.compiler.add_module(operations) + +problem_size = cutlass_bindings.conv.Conv2dProblemSize( + cutlass_bindings.Tensor4DCoord(args.nhwc[0], args.nhwc[1], args.nhwc[2], args.nhwc[3]), + cutlass_bindings.Tensor4DCoord(args.krsc[0], args.krsc[1], args.krsc[2], args.krsc[3]), + cutlass_bindings.Tensor4DCoord(args.pad[0], args.pad[1], args.pad[2], args.pad[3]), + cutlass_bindings.MatrixCoord(args.stride[0], args.stride[1]), + cutlass_bindings.MatrixCoord(args.dilation[0], args.dilation[1]), + cutlass_bindings.conv.Mode.cross_correlation, + args.split_k_slices, 1 +) + + +# User-provide inputs +tensor_A_size = cutlass_bindings.conv.implicit_gemm_tensor_a_size( + conv_kind, problem_size +) +tensor_B_size = cutlass_bindings.conv.implicit_gemm_tensor_b_size( + conv_kind, problem_size +) +if args.bias: + tensor_C_size = cutlass_bindings.conv.implicit_gemm_tensor_c_extent( + conv_kind, problem_size + ).at(3) +else: + tensor_C_size = cutlass_bindings.conv.implicit_gemm_tensor_c_size( + conv_kind, problem_size + ) + +tensor_D_size = cutlass_bindings.conv.implicit_gemm_tensor_c_size( + conv_kind, problem_size + ) + +if args.element_a != "int8": + tensor_A = torch.ceil(torch.empty(size=(tensor_A_size,), dtype=getattr(torch, args.element_a), device="cuda").uniform_(-8.5, 7.5)) +else: + tensor_A = torch.empty(size=(tensor_A_size,), dtype=getattr(torch, args.element_a), device="cuda").uniform_(-2, 2) + +if args.element_b != "int8": + tensor_B = torch.ceil(torch.empty(size=(tensor_B_size,), dtype=getattr(torch, args.element_b), device="cuda").uniform_(-8.5, 7.5)) +else: + tensor_B = torch.empty(size=(tensor_B_size,), dtype=getattr(torch, args.element_b), device="cuda").uniform_(-2, 2) + +if args.element_c != "int8": + tensor_C = torch.ceil(torch.empty(size=(tensor_C_size,), dtype=getattr(torch, args.element_c), device="cuda").uniform_(-8.5, 7.5)) +else: + tensor_C = torch.empty(size=(tensor_C_size,), dtype=getattr(torch, args.element_c), device="cuda").uniform_(-2, 2) + +tensor_D = torch.ones(size=(tensor_D_size,), dtype=getattr(torch, args.element_c), device="cuda") + +arguments = Conv2dArguments( + operation=operation, problem_size=problem_size, A=tensor_A, + B=tensor_B, C=tensor_C, D=tensor_D, + output_op = operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)), + split_k_mode=getattr(cutlass_bindings.conv.SplitKMode, args.split_k_mode), + split_k_slices=problem_size.split_k_slices +) + +if args.split_k_mode == "Parallel" and args.split_k_slices > 1: + implicit_gemm_size = cutlass_bindings.conv.implicit_gemm_problem_size(conv_kind, 
arguments.problem_size) + reduction_arguments = ReductionArguments( + reduction_operation, + problem_size=[implicit_gemm_size.m(), implicit_gemm_size.n()], + partitions=problem_size.split_k_slices, + workspace=arguments.ptr_D, + destination=tensor_D, + source=tensor_C, + output_op = reduction_operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)), + bias = arguments.bias + ) + +operation.run(arguments) + +if args.split_k_mode == "Parallel" and args.split_k_slices > 1: + reduction_operation.run(reduction_arguments) + reduction_arguments.sync() +else: + arguments.sync() + +reference_model = Conv2dReferenceModule(A, B, C, conv_kind) + +tensor_D_ref = reference_model.run(tensor_A, tensor_B, tensor_C, arguments.problem_size, args.alpha, args.beta, args.bias) +if (args.activation_function != "identity"): + tensor_D_ref = getattr(F, args.activation_function)(*([tensor_D_ref,] + args.activation_args)) + +try: + assert torch.equal(tensor_D, tensor_D_ref) +except: + assert torch.allclose(tensor_D, tensor_D_ref, rtol=1e-2) +print("Passed.") diff --git a/examples/40_cutlass_py/customizable/gemm.py b/examples/40_cutlass_py/customizable/gemm.py new file mode 100644 index 0000000000..670294ad2d --- /dev/null +++ b/examples/40_cutlass_py/customizable/gemm.py @@ -0,0 +1,331 @@ +################################################################################ +# +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +import sys +print("This example is deprecated. 
Please see examples/python for examples of using " + "the CUTLASS Python interface.") +sys.exit(0) + +import numpy as np +import cutlass.backend as pycutlass +from cutlass.backend import * +from cutlass.backend.utils.device import device_cc +import cutlass_bindings +from bfloat16 import bfloat16 + +import argparse + + +# parse the arguments +parser = argparse.ArgumentParser(description="Launch CUTLASS GEMM kernels from Python: 'D = alpha * A * B + beta * C'") + +# Operation description +# math instruction description +parser.add_argument("-i", "--instruction_shape", + default=[1, 1, 1], nargs=3, type=int, + help="This option describes the size of MMA op") +parser.add_argument("-ta", "--element_a", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor A') +parser.add_argument("-tb", "--element_b", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor B') +parser.add_argument("-tc", "--element_c", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor C and output tensor D') +parser.add_argument("-tacc", "--element_acc", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of accumulator') +parser.add_argument('-m', "--math", default="multiply_add", + type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction") +parser.add_argument('-op', "--opcode", default="Simt", type=str, + choices=["Simt", 'TensorOp'], + help="This option describes whether you want to use tensor \ + cores (TensorOp) or regular SIMT cores (Simt) on GPU SM") +# tile description +parser.add_argument("-b", "--threadblock_shape", + default=[128, 128, 8], nargs=3, type=int, + help="This option describes the tile size a thread block with compute") +parser.add_argument("-s", "--stages", default=4, + type=int, help="Number of pipelines you want to use") +parser.add_argument("-w", "--warp_count", default=[4, 2, 1], nargs=3, type=int, + help="This option describes the number of warps along M, N, and K of the threadblock") +parser.add_argument("-cc", "--compute_capability", default=80, + type=int, help="This option describes CUDA SM architecture number") +# A +parser.add_argument('-la', "--layout_a", default="RowMajor", type=str, choices=[ + "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], + help="Memory layout of input tensor A") +parser.add_argument('-aa', '--alignment_a', default=1, + type=int, help="Memory alignement of input tensor A") +# B +parser.add_argument('-lb', "--layout_b", default="RowMajor", type=str, choices=[ + "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], + help="Memory layout of input tensor B") +parser.add_argument('-ab', '--alignment_b', default=1, + type=int, help="Memory alignment of input tensor B") +# C +parser.add_argument('-lc', "--layout_c", default="RowMajor", type=str, choices=[ + "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], + help="Memory layout of input tensor C and output tensor D") +parser.add_argument('-ac', '--alignment_c', default=1, + type=int, help="Memory alignment of input tensor C and output tensor D") +# epilogue +parser.add_argument("-te", "--element_epilogue", default="float32", type=str, + 
choices=['float64', 'float32', 'float16', 'bfloat16'], help='Epilogue datatype') +parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination", + type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'], + help="This option describes the epilogue part of the kernel") +# swizzling +parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[ + "IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle", "BatchedIdentitySwizzle"], + help="This option describes how thread blocks are scheduled on GPU") + +# Argument +parser.add_argument("-p", "--problem_size", + default=[128, 128, 128], nargs=3, type=int, + help="GEMM problem size M, N, K") +parser.add_argument("-alpha", "--alpha", default=1.0, type=float, + help="Scaling factor of A * B") +parser.add_argument("-beta", "--beta", default=0.0, type=float, + help="Scaling factor of C") +parser.add_argument("-gm", "--gemm_mode", default="Gemm", type=str, + choices=["Gemm", "GemmSplitKParallel", "Batched", "Array"], + help="GEMM mode. Gemm is used for non-splitK or serial-splitK. \ + GemmSplitKParallel is used for parallel splitK") +parser.add_argument('-k', '--split_k_slices', default=1, + type=int, help="Number of split-k partitions. (default 1)") +parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector") +parser.add_argument('-batch', '--batch', default=1, type=int, help="batch size for batched GEMM") + +# Activation function +parser.add_argument("-activ", "--activation_function", default="identity", + choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function") +parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float, + help="addition arguments for activation") +parser.add_argument('--print_cuda', action="store_true", + help="print the underlying CUDA kernel") + +try: + args = parser.parse_args() +except: + sys.exit(0) + +cc = device_cc() +if args.compute_capability != cc: + raise Exception(("Parameter --compute-capability of {} " + "does not match that of the device of {}.").format(args.compute_capability, cc)) + +pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) +pycutlass.compiler.nvcc() + +np.random.seed(0) + +element_a = getattr(cutlass_bindings, args.element_a) +element_b = getattr(cutlass_bindings, args.element_b) +element_c = getattr(cutlass_bindings, args.element_c) +element_acc = getattr(cutlass_bindings, args.element_acc) +math_operation = getattr(MathOperation, args.math) +opclass = getattr(cutlass_bindings.OpClass, args.opcode) + +math_inst = MathInstruction( + args.instruction_shape, element_a, element_b, + element_acc, opclass, math_operation +) + +tile_description = TileDescription( + args.threadblock_shape, args.stages, args.warp_count, + math_inst +) + +layout_a = getattr(cutlass_bindings, args.layout_a) +layout_b = getattr(cutlass_bindings, args.layout_b) +layout_c = getattr(cutlass_bindings, args.layout_c) + +A = TensorDescription( + element_a, layout_a, args.alignment_a +) + +B = TensorDescription( + element_b, layout_b, args.alignment_b +) + +C = TensorDescription( + element_c, layout_c, args.alignment_c +) + +element_epilogue = getattr(cutlass_bindings, args.element_epilogue) +if (args.activation_function == "identity" + or (args.gemm_mode == "GemmSplitKParallel" and args.split_k_slices > 1)): + # + epilogue_functor = getattr(pycutlass, 
args.epilogue_functor)( + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) +else: + epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")( + getattr(pycutlass, args.activation_function)(element_epilogue), + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + +swizzling_functor = getattr(cutlass_bindings, args.swizzling_functor) + +operation = GemmOperationUniversal( + arch=args.compute_capability, tile_description=tile_description, + A=A, B=B, C=C, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor +) + +if args.print_cuda: + print(operation.rt_module.emit()) + +operations = [operation, ] + +if args.gemm_mode == "GemmSplitKParallel": + if (args.activation_function == "identity"): + epilogue_functor_reduction = getattr(pycutlass, args.epilogue_functor)( + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + else: + epilogue_functor_reduction = getattr(pycutlass, "LinearCombinationGeneric")( + getattr(pycutlass, args.activation_function)(element_epilogue), + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) + + reduction_operation = ReductionOperation( + shape=cutlass_bindings.MatrixCoord(4, 32 * C.alignment), + C=C, element_accumulator=element_acc, + element_compute=element_epilogue, + epilogue_functor=epilogue_functor_reduction, + count=C.alignment + ) + operations.append(reduction_operation) + +pycutlass.compiler.add_module(operations) + +# User-provide inputs + +problem_size = cutlass_bindings.gemm.GemmCoord( + args.problem_size[0], args.problem_size[1], args.problem_size[2]) + +tensor_a_size = args.batch * problem_size.m() * problem_size.k() +if args.element_a != "int8": + if args.element_a == "bfloat16": + tensor_A = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(tensor_a_size,)) + ).astype(bfloat16) + else: + tensor_A = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(tensor_a_size,)) + ).astype(getattr(np, args.element_a)) +else: + tensor_A = np.random.uniform( + low=-2, high=2,size=(tensor_a_size,) + ).astype(getattr(np, args.element_a)) + +tensor_b_size = args.batch * problem_size.k() * problem_size.n() +if args.element_b != "int8": + if args.element_b == "bfloat16": + tensor_B = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(tensor_b_size,)) + ).astype(bfloat16) + else: + tensor_B = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(tensor_b_size,)) + ).astype(getattr(np, args.element_b)) +else: + tensor_B = np.random.uniform( + low=-2, high=2, size=(tensor_b_size,) + ).astype(getattr(np, args.element_b)) + +if args.element_c != "int8": + if args.bias: + if args.layout_c == "RowMajor": + tensor_c_size = args.batch * problem_size.n() + elif args.layout_c == "ColumnMajor": + tensor_c_size = args.batch * problem_size.m() + else: + raise ValueError(args.layout_c) + else: + tensor_c_size = args.batch * problem_size.m() * problem_size.n() + if args.element_c == "bfloat16": + tensor_C = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(tensor_c_size,)) + ).astype(bfloat16) + else: + tensor_C = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(tensor_c_size,)) + ).astype(getattr(np, args.element_c)) +else: + tensor_C = np.random.uniform( + low=-2, high=2, size=(args.batch * problem_size.m() * problem_size.n(),) + ).astype(getattr(np, args.element_c)) + +tensor_D = np.zeros( + shape=(args.batch * problem_size.m() * problem_size.n(),) +).astype(getattr(np, args.element_c)) + +output_op = operation.epilogue_type(*([args.alpha, 
args.beta] + args.activation_args)) + +arguments = GemmArguments( + operation=operation, problem_size=problem_size, + A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, + output_op=output_op, + gemm_mode=getattr(cutlass_bindings.gemm.Mode, args.gemm_mode), + split_k_slices=args.split_k_slices, batch=args.batch +) + +if args.gemm_mode == "GemmSplitKParallel": + reduction_arguments = ReductionArguments( + operation=reduction_operation, + problem_size=[problem_size.m(), problem_size.n()], + partitions=args.split_k_slices, workspace=arguments.ptr_D, + destination=tensor_D, source=tensor_C, + output_op=reduction_operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)), + bias = arguments.bias + ) + +operation.run(arguments) + +if args.gemm_mode == "GemmSplitKParallel": + reduction_operation.run(reduction_arguments) + reduction_arguments.sync() +else: + arguments.sync() + +# run the host reference module +reference = ReferenceModule(A, B, C) +tensor_D_ref = reference.run( + tensor_A, tensor_B, tensor_C, problem_size, args.alpha, args.beta, args.bias, args.batch) + +tensor_D_ref = getattr(pycutlass, args.activation_function).numpy(*([tensor_D_ref,] + args.activation_args)) + +try: + assert np.array_equal(tensor_D, tensor_D_ref) +except: + assert np.allclose(tensor_D, tensor_D_ref, atol=1e-5) +print("Passed.") diff --git a/examples/40_cutlass_py/customizable/gemm_grouped.py b/examples/40_cutlass_py/customizable/gemm_grouped.py new file mode 100644 index 0000000000..ac2adefacb --- /dev/null +++ b/examples/40_cutlass_py/customizable/gemm_grouped.py @@ -0,0 +1,298 @@ +################################################################################ +# +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ + +import sys +print("This example is deprecated. 
Please see examples/python for examples of using " + "the CUTLASS Python interface.") +sys.exit(0) + +import numpy as np +import cutlass.backend as pycutlass +from cutlass.backend import * +from cutlass.backend.utils.device import device_cc +import csv + +import argparse + +# parse the arguments +parser = argparse.ArgumentParser( + description="Launch CUTLASS GEMM Grouped kernels from Python") + +# Operation description +# math instruction description +parser.add_argument("-i", "--instruction_shape", + default=[1, 1, 1], nargs=3, type=int, + help="This option describes the size of MMA op") +parser.add_argument("-ta", "--element_a", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor A') +parser.add_argument("-tb", "--element_b", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor B') +parser.add_argument("-tc", "--element_c", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of elements in input tensor C and output tensor D') +parser.add_argument("-tacc", "--element_acc", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16', 'int32', 'int8'], + help='Data type of accumulator') +parser.add_argument('-m', "--math", default="multiply_add", + type=str, choices=["multiply_add", "multiply_add_fast_bf16", "multiply_add_fast_f32"], help="math instruction") +parser.add_argument('-op', "--opcode", default="Simt", type=str, + choices=["Simt", 'TensorOp'], help='This option describes whether you want to use tensor \ + cores (TensorOp) or regular SIMT cores (Simt) on GPU SM') +# tile description +parser.add_argument("-b", "--threadblock_shape", + default=[128, 128, 8], nargs=3, type=int, + help="This option describes the tile size a thread block with compute") +parser.add_argument("-s", "--stages", default=4, + type=int, help="Number of pipelines you want to use") +parser.add_argument("-w", "--warp_count", default=[ + 4, 2, 1], nargs=3, type=int, + help="This option describes the number of warps along M, N, and K of the threadblock") +parser.add_argument("-cc", "--compute_capability", default=80, + type=int, help="This option describes CUDA SM architecture number") +# A +parser.add_argument('-la', "--layout_a", default="RowMajor", type=str, choices=[ + "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], + help="Memory layout of input tensor A") +parser.add_argument('-aa', '--alignment_a', default=1, + type=int, help="Memory alignment of input tensor A") +# B +parser.add_argument('-lb', "--layout_b", default="RowMajor", type=str, choices=[ + "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], + help="Memory layout of input tensor B") +parser.add_argument('-ab', '--alignment_b', default=1, + type=int, help="Memory alignment of input tensor B") +# C +parser.add_argument('-lc', "--layout_c", default="RowMajor", type=str, choices=[ + "RowMajor", "ColumnMajor", "RowMajorInterleaved32", "ColumnMajorInterleaved32"], + help="Memory layout of input tensor C and output tensor D") +parser.add_argument('-ac', '--alignment_c', default=1, + type=int, help="Memory alignment of input tensor C and output tensor D") +# epilogue +parser.add_argument("-te", "--element_epilogue", default="float32", type=str, + choices=['float64', 'float32', 'float16', 'bfloat16'], 
help='Epilogue datatype') +parser.add_argument("-ep", "--epilogue_functor", default="LinearCombination", + type=str, choices=['LinearCombination', 'FastLinearCombinationClamp', 'LinearCombinationClamp'], + help="This option describes the epilogue part of the kernel") +# swizzling +parser.add_argument("-sw", "--swizzling_functor", default="IdentitySwizzle1", type=str, choices=[ + "IdentitySwizzle1", "IdentitySwizzle2", "IdentitySwizzle4", "IdentitySwizzle8", "HorizontalSwizzle"], + help="This option describes how thread blocks are scheduled on GPU. \ + NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels. \ + This parameter is passed in at present to match the APIs of other kernels. The parameter \ + is unused within the kernel") +# precompute mode +parser.add_argument("-pm", "--precompute_mode", + default="Device", type=str, choices=["Host", "Device"], + help="Grouped Gemm Scheduing on device only (Device) or using host precompute (Host)") +# arguments +parser.add_argument("-p", "--problem_size_dir", type=str, default="grouped_gemm_problem_size.csv", + help="path to the csv file contains the problem sizes") +parser.add_argument("-alpha", "--alpha", default=1.0, type=float, help="alpha") +parser.add_argument("-beta", "--beta", default=0.0, type=float, help="beta") +parser.add_argument('-bias', '--bias', action='store_true', help="C is bias vector") + +# Activation function +parser.add_argument("-activ", "--activation_function", default="identity", + choices=["identity", "relu", "leaky_relu", "tanh", "sigmoid", "silu", "hardswish", "gelu"], help="activation function") +parser.add_argument("-activ_arg", "--activation_args", default=[], nargs="+", type=float, + help="addition arguments for activation") +parser.add_argument('--print_cuda', action="store_true", + help="print the underlying CUDA kernel") + +try: + args = parser.parse_args() +except: + sys.exit(0) + +cc = device_cc() +if args.compute_capability != cc: + raise Exception(("Parameter --compute-capability of {} " + "does not match that of the device of {}.").format(args.compute_capability, cc)) + +pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) + +np.random.seed(0) + +element_a = getattr(cutlass_bindings, args.element_a) +element_b = getattr(cutlass_bindings, args.element_b) +element_c = getattr(cutlass_bindings, args.element_c) +element_acc = getattr(cutlass_bindings, args.element_acc) +math_operation = getattr(MathOperation, args.math) +opclass = getattr(cutlass_bindings.OpClass, args.opcode) + +math_inst = MathInstruction( + args.instruction_shape, element_a, element_b, + element_acc, opclass, math_operation +) + +tile_description = TileDescription( + args.threadblock_shape, args.stages, args.warp_count, + math_inst +) + +layout_a = getattr(cutlass_bindings, args.layout_a) +layout_b = getattr(cutlass_bindings, args.layout_b) +layout_c = getattr(cutlass_bindings, args.layout_c) + +A = TensorDescription( + element_a, layout_a, args.alignment_a +) + +B = TensorDescription( + element_b, layout_b, args.alignment_b +) + +C = TensorDescription( + element_c, layout_c, args.alignment_c +) + +element_epilogue = getattr(cutlass_bindings, args.element_epilogue) +if args.activation_function == "identity": + epilogue_functor = getattr(pycutlass, args.epilogue_functor)( + C.element, C.alignment, math_inst.element_accumulator, element_epilogue) +else: + epilogue_functor = getattr(pycutlass, "LinearCombinationGeneric")( + getattr(pycutlass, args.activation_function)(element_epilogue), + C.element, 
C.alignment, math_inst.element_accumulator, element_epilogue) +swizzling_functor = getattr(cutlass_bindings, args.swizzling_functor) +precompute_mode = getattr(SchedulerMode, args.precompute_mode) + +operation = GemmOperationGrouped( + arch=args.compute_capability, tile_description=tile_description, + A=A, B=B, C=C, + epilogue_functor=epilogue_functor, swizzling_functor=swizzling_functor, + precompute_mode=precompute_mode +) + +if args.print_cuda: + print(operation.rt_module.emit()) + +pycutlass.compiler.add_module([operation, ]) + +reference_module = ReferenceModule(A, B, C) + +# get problems +problem_sizes = [] +with open(args.problem_size_dir) as csv_file: + reader = csv.reader(csv_file) + for row in reader: + problem_sizes.append( + cutlass_bindings.gemm.GemmCoord(int(row[0]), int(row[1]), int(row[2])) + ) + +problem_count = len(problem_sizes) + +tensor_As = [] +tensor_Bs = [] +tensor_Cs = [] +tensor_Ds = [] +problem_sizes_coord = [] +tensor_D_refs = [] + +for problem_size in problem_sizes: + if args.element_a != "int8": + if args.element_a == "bfloat16": + tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m() + * problem_size.k(),))).astype(bfloat16) + else: + tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.m() + * problem_size.k(),))).astype(getattr(np, args.element_a)) + else: + tensor_A = np.random.uniform(low=-2, high=2, size=(problem_size.m() + * problem_size.k(),)).astype(getattr(np, args.element_a)) + + if args.element_b != "int8": + if args.element_b == "bfloat16": + tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k() + * problem_size.n(),))).astype(bfloat16) + else: + tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(problem_size.k() + * problem_size.n(),))).astype(getattr(np, args.element_b)) + else: + tensor_B = np.random.uniform(low=-2, high=2, size=(problem_size.k() + * problem_size.n(),)).astype(getattr(np, args.element_b)) + + if args.element_c != "int8": + if args.bias: + if args.layout_c == "RowMajor": + c_size = problem_size.n() + elif args.layout_c == "ColumnMajor": + c_size = problem_size.m() + else: + raise ValueError(args.layout_c) + else: + c_size = problem_size.m() * problem_size.n() + if args.element_c == "bfloat16": + tensor_C = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(c_size,)) + ).astype(bfloat16) + else: + tensor_C = np.ceil( + np.random.uniform(low=-8.5, high=7.5, size=(c_size,)) + ).astype(getattr(np, args.element_c)) + else: + tensor_C = np.random.uniform( + low=-2, high=2, size=(problem_size.m() * problem_size.n(),) + ).astype(getattr(np, args.element_c)) + tensor_D = np.zeros( + shape=(problem_size.m() * problem_size.n(),) + ).astype(getattr(np, args.element_c)) + + tensor_As.append(tensor_A) + tensor_Bs.append(tensor_B) + tensor_Cs.append(tensor_C) + tensor_Ds.append(tensor_D) + tensor_D_ref = reference_module.run( + tensor_A, tensor_B, tensor_C, problem_size, + args.alpha, args.beta, args.bias) + tensor_D_ref = getattr(pycutlass, args.activation_function).numpy(*([tensor_D_ref,] + args.activation_args)) + tensor_D_refs.append(tensor_D_ref) + problem_sizes_coord.append(problem_size) + +arguments = GemmGroupedArguments( + operation, problem_sizes_coord, tensor_As, tensor_Bs, tensor_Cs, tensor_Ds, + output_op=operation.epilogue_type(*([args.alpha, args.beta] + args.activation_args)) +) + +operation.run(arguments) + +arguments.sync() + +for tensor_d, tensor_d_ref in zip(tensor_Ds, tensor_D_refs): + try: + assert np.array_equal(tensor_d, 
tensor_d_ref) + except: + assert np.allclose(tensor_d, tensor_d_ref, rtol=1e-5) + +print("Passed.") diff --git a/examples/40_cutlass_py/customizable/grouped_gemm_problem_size.csv b/examples/40_cutlass_py/customizable/grouped_gemm_problem_size.csv new file mode 100644 index 0000000000..d1d0dd00b2 --- /dev/null +++ b/examples/40_cutlass_py/customizable/grouped_gemm_problem_size.csv @@ -0,0 +1,3 @@ +128,128,128 +128,128,256 +512,128,384 diff --git a/examples/40_cutlass_py/gemm.py b/examples/40_cutlass_py/gemm.py new file mode 100644 index 0000000000..076f758287 --- /dev/null +++ b/examples/40_cutlass_py/gemm.py @@ -0,0 +1,153 @@ +################################################################################ +# +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ +""" +Basic example of using the CUTLASS Python interface to run a GEMM +""" + +import sys +print("This example is deprecated. 
Please see examples/python for examples of using " + "the CUTLASS Python interface.") +sys.exit(0) + +import argparse +import numpy as np + +import cutlass_bindings +import cutlass.backend as pycutlass +from cutlass.backend import * +from cutlass.backend.utils.device import device_cc + + +parser = argparse.ArgumentParser(description="Launch a GEMM kernel from Python: 'D = alpha * A * B + beta * C'") +parser.add_argument("--m", default=128, type=int, help="M dimension of the GEMM") +parser.add_argument("--n", default=128, type=int, help="N dimension of the GEMM") +parser.add_argument("--k", default=128, type=int, help="K dimension of the GEMM") +parser.add_argument('--print_cuda', action="store_true", help="Print the underlying CUDA kernel") + +try: + args = parser.parse_args() +except: + sys.exit(0) + +# Check that the device is of a sufficient compute capability +cc = device_cc() +assert cc >= 70, "The CUTLASS Python GEMM example requires compute capability greater than or equal to 70." + +alignment = 8 +assert args.m % alignment == 0, "M dimension of size {} is not divisible by alignment of {}".format(args.m, alignment) +assert args.n % alignment == 0, "N dimension of size {} is not divisible by alignment of {}".format(args.n, alignment) +assert args.k % alignment == 0, "K dimension of size {} is not divisible by alignment of {}".format(args.k, alignment) + +np.random.seed(0) + +# Allocate a pool of device memory to be used by the kernel +pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) + +# Set the compiler to use to NVCC +pycutlass.compiler.nvcc() + +# Set up A, B, C and accumulator +A = TensorDescription(cutlass_bindings.float16, cutlass_bindings.ColumnMajor, alignment) +B = TensorDescription(cutlass_bindings.float16, cutlass_bindings.RowMajor, alignment) +C = TensorDescription(cutlass_bindings.float32, cutlass_bindings.ColumnMajor, alignment) +element_acc = cutlass_bindings.float32 +element_epilogue = cutlass_bindings.float32 + +# Select instruction shape based on the Tensor Core instructions supported +# by the device on which we are running +if cc == 70: + instruction_shape = [8, 8, 4] +elif cc == 75: + instruction_shape = [16, 8, 8] +else: + # Use CUTLASS kernels for CC 80 by default (e.g., for cases in which SM86 is used) + cc = 80 + instruction_shape = [16, 8, 16] + +math_inst = MathInstruction( + instruction_shape, + A.element, B.element, element_acc, + cutlass_bindings.OpClass.TensorOp, + MathOperation.multiply_add +) + +tile_description = TileDescription( + [128, 128, 32], # Threadblock shape + 2, # Number of stages + [2, 2, 1], # Number of warps within each dimension of the threadblock shape + math_inst +) + +epilogue_functor = pycutlass.LinearCombination(C.element, C.alignment, element_acc, element_epilogue) + +operation = GemmOperationUniversal( + arch=cc, tile_description=tile_description, + A=A, B=B, C=C, + epilogue_functor=epilogue_functor) + +if args.print_cuda: + print(operation.rt_module.emit()) + +operations = [operation, ] + +# Compile the operation +pycutlass.compiler.add_module(operations) + +# Randomly initialize tensors +tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(args.m * args.k,))).astype(np.float16) +tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(args.k * args.n,))).astype(np.float16) +tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(args.m * args.n,))).astype(np.float32) +tensor_D = np.zeros(shape=(args.m * args.n,)).astype(np.float32) + +problem_size = 
cutlass_bindings.gemm.GemmCoord(args.m, args.n, args.k) +alpha = 1. +beta = 0. + +arguments = GemmArguments( + operation=operation, problem_size=problem_size, + A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D, + output_op=operation.epilogue_type(alpha, beta)) + +# Run the operation +operation.run(arguments) +arguments.sync() + +# Run the host reference module and compare to the CUTLASS result +reference = ReferenceModule(A, B, C) +tensor_D_ref = reference.run(tensor_A, tensor_B, tensor_C, problem_size, alpha, beta) + +try: + assert np.array_equal(tensor_D, tensor_D_ref) +except: + assert np.allclose(tensor_D, tensor_D_ref, atol=1e-5) + +print("Passed.") diff --git a/examples/40_cutlass_py/gemm_grouped.py b/examples/40_cutlass_py/gemm_grouped.py new file mode 100644 index 0000000000..9ba2fa313a --- /dev/null +++ b/examples/40_cutlass_py/gemm_grouped.py @@ -0,0 +1,172 @@ +################################################################################ +# +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################ +""" +Basic example of using the CUTLASS Python interface to run a grouped GEMM +""" + +import sys +print("This example is deprecated. Please see examples/python for examples of using " + "the CUTLASS Python interface.") +sys.exit(0) + +import argparse +import numpy as np + +import cutlass_bindings +import cutlass.backend as pycutlass +from cutlass.backend import * +from cutlass.backend.utils.device import device_cc + + +parser = argparse.ArgumentParser(description="Launch a grouped GEMM kernel from Python") +parser.add_argument('--print_cuda', action="store_true", help="Print the underlying CUDA kernel") + +try: + args = parser.parse_args() +except: + sys.exit(0) + +# Check that the device is of a sufficient compute capability +cc = device_cc() +assert cc >= 70, "The CUTLASS Python grouped GEMM example requires compute capability greater than or equal to 70." 
+ +np.random.seed(0) + +# Allocate a pool of device memory to be used by the kernel +pycutlass.get_memory_pool(init_pool_size=2**30, max_pool_size=2**32) + +# Set the compiler to use to NVCC +pycutlass.compiler.nvcc() + +# Set up A, B, C and accumulator +alignment = 1 +A = TensorDescription(cutlass_bindings.float16, cutlass_bindings.ColumnMajor, alignment) +B = TensorDescription(cutlass_bindings.float16, cutlass_bindings.RowMajor, alignment) +C = TensorDescription(cutlass_bindings.float32, cutlass_bindings.ColumnMajor, alignment) +element_acc = cutlass_bindings.float32 +element_epilogue = cutlass_bindings.float32 + +# Select instruction shape based on the Tensor Core instructions supported +# by the device on which we are running +if cc == 70: + instruction_shape = [8, 8, 4] +elif cc == 75: + instruction_shape = [16, 8, 8] +else: + # Use CUTLASS kernels for CC 80 by default (e.g., for cases in which SM86 is used) + cc = 80 + instruction_shape = [16, 8, 16] + +math_inst = MathInstruction( + instruction_shape, + A.element, B.element, element_acc, + cutlass_bindings.OpClass.TensorOp, + MathOperation.multiply_add +) + +tile_description = TileDescription( + [128, 128, 32], # Threadblock shape + 2, # Number of stages + [2, 2, 1], # Number of warps within each dimension of the threadblock shape + math_inst +) + +epilogue_functor = pycutlass.LinearCombination(C.element, C.alignment, element_acc, element_epilogue) + +operation = GemmOperationGrouped( + arch=cc, tile_description=tile_description, + A=A, B=B, C=C, + epilogue_functor=epilogue_functor, + precompute_mode=SchedulerMode.Device) + +if args.print_cuda: + print(operation.rt_module.emit()) + +operations = [operation, ] + +# Compile the operation +pycutlass.compiler.add_module(operations) + +# Initialize tensors for each problem in the group +problem_sizes = [ + cutlass_bindings.gemm.GemmCoord(128, 128, 64), + cutlass_bindings.gemm.GemmCoord(512, 256, 128) +] +problem_count = len(problem_sizes) + +alpha = 1. +beta = 0. 
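+
+# The LinearCombination epilogue computes D = alpha * (A @ B) + beta * C for
+# each problem in the group; with alpha = 1 and beta = 0 the source tensor C
+# is effectively ignored and D reduces to the plain matrix product:
+#   D_i = 1.0 * (A_i @ B_i) + 0.0 * C_i == A_i @ B_i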
+ +tensor_As = [] +tensor_Bs = [] +tensor_Cs = [] +tensor_Ds = [] +tensor_D_refs = [] + +reference = ReferenceModule(A, B, C) + +for problem_size in problem_sizes: + # Randomly initialize tensors + m = problem_size.m() + n = problem_size.n() + k = problem_size.k() + tensor_A = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(m * k,))).astype(np.float16) + tensor_B = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(k * n,))).astype(np.float16) + tensor_C = np.ceil(np.random.uniform(low=-8.5, high=7.5, size=(m * n,))).astype(np.float32) + tensor_D = np.zeros(shape=(m * n,)).astype(np.float32) + + tensor_As.append(tensor_A) + tensor_Bs.append(tensor_B) + tensor_Cs.append(tensor_C) + tensor_Ds.append(tensor_D) + + # Run the reference GEMM + tensor_D_ref = reference.run(tensor_A, tensor_B, tensor_C, problem_size, alpha, beta) + tensor_D_refs.append(tensor_D_ref) + +arguments = GemmGroupedArguments( + operation, problem_sizes, tensor_As, tensor_Bs, tensor_Cs, tensor_Ds, + output_op=operation.epilogue_type(alpha, beta) +) + +# Run the operation +operation.run(arguments) +arguments.sync() + +# Compare the CUTLASS result to the host reference result +for tensor_d, tensor_d_ref in zip(tensor_Ds, tensor_D_refs): + try: + assert np.array_equal(tensor_d, tensor_d_ref) + except: + assert np.allclose(tensor_d, tensor_d_ref, rtol=1e-5) + +print("Passed.") diff --git a/examples/40_cutlass_py/test-cutlass-py.py b/examples/40_cutlass_py/test-cutlass-py.py deleted file mode 100644 index e1ee636b4d..0000000000 --- a/examples/40_cutlass_py/test-cutlass-py.py +++ /dev/null @@ -1,169 +0,0 @@ - -# System modules -import numpy as np -import os.path -import sys -import ctypes - -# CUDA Python modules -from cuda import cuda -from cuda import nvrtc - -# CUTLASS modules -import library -import manifest as cutlass_manifest -import generator -import rt - - -# -# Construct an SGEMM -# - -manifest = cutlass_manifest.Manifest() - -generator.GenerateSM50_Simt(manifest, "11.5.0") - -# -# Construct a GEMM operation -# - -operation = manifest.operations_by_name['cutlass_simt_sgemm_128x128_8x2_nt_align1'] - -# -# Construct a runtime GEMM operation -# -gemm = rt.Gemm(operation) - -# -# Initialize context -# -err, = cuda.cuInit(0) - -if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("CUDA Error %s" % str(err)) - -err, device = cuda.cuDeviceGet(0) - -if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("CUDA Error %s" % str(err)) - -err, context = cuda.cuCtxCreate(0, device) - -if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("CUDA Error %s" % str(err)) - -# -# Construct a module -# - -architectures = [80,] -include_paths = [ - '../../include', - '../../tools/util/include', -] - -compilation_options = rt.CompilationOptions(architectures, include_paths) - -module = rt.Module('module.cu', [gemm], compilation_options) - -# -# Setup a workspace -# - -M, N, K = (128, 128, 128) - -tensor_A = np.ndarray(M * K, dtype=np.float32) -tensor_B = np.ndarray(N * K, dtype=np.float32) -tensor_C = np.ndarray(M * N, dtype=np.float32) -tensor_D = np.ndarray(M * N, dtype=np.float32) - -err, tensor_A_d = cuda.cuMemAlloc(tensor_A.size * tensor_A.itemsize) -if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("CUDA Error %s" % str(err)) - -err, tensor_B_d = cuda.cuMemAlloc(tensor_B.size * tensor_B.itemsize) -if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("CUDA Error %s" % str(err)) - -err, tensor_C_d = cuda.cuMemAlloc(tensor_C.size * tensor_C.itemsize) -if err != cuda.CUresult.CUDA_SUCCESS: - raise 
RuntimeError("CUDA Error %s" % str(err)) - -err, tensor_D_d = cuda.cuMemAlloc(tensor_D.size * tensor_D.itemsize) -if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("CUDA Error %s" % str(err)) - -err, stream = cuda.cuStreamCreate(0) -if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("CUDA Error %s" % str(err)) - -tensors = [ - (tensor_A_d, tensor_A), - (tensor_B_d, tensor_B), - (tensor_C_d, tensor_C), - (tensor_D_d, tensor_D) -] - -for tensor_device, tensor_host in tensors: - bytes = tensor_host.size * tensor_host.itemsize - print("Tensor has dimensions: %s (%d bytes)" % (str(tensor_host.size), tensor_host.itemsize)) - err, = cuda.cuMemcpyHtoDAsync(tensor_device, tensor_host, bytes, stream) - print("updating tensor in device memory ", hex(int(tensor_device))) - if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError('CUDA Error %s' % str(err)) - -# -# Initialize a host buffer -# - -arguments = rt.GemmArguments() - -arguments.problem_size = rt.GemmCoord(M, N, K) - -arguments.A = rt.TensorRef(tensor_A_d, M) -arguments.B = rt.TensorRef(tensor_B_d, N) -arguments.C = rt.TensorRef(tensor_C_d, M) -arguments.D = rt.TensorRef(tensor_D_d, M) - -host_workspace = bytearray(gemm.get_host_workspace_size(arguments)) -device_workspace = None - -launch_config = gemm.plan(arguments) - -byte_count = gemm.initialize(host_workspace, device_workspace, launch_config, arguments) - -# -# Launch the kernel -# - -err = gemm.run(host_workspace, device_workspace, launch_config) - -if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError('CUDA Error %s' % str(err)) - -# -# Verify results -# -err, = cuda.cuStreamSynchronize(stream) - -if err != cuda.CUresult.CUDA_SUCCESS: - raise RuntimeError("CUDA Error %s" % str(err)) - - -# -# Debug reporting of byte array contents -# - -def PrintBytearray(host_workspace): - uint_str = None - prefix = None - print("uint32_t host_workspace[] = {") - for idx, byte in enumerate(host_workspace): - if not (idx % 4): - if uint_str is not None: - print(prefix, uint_str, ",") - prefix = "/* offset: %d B */ 0x" % idx - uint_str = "" - uint_str = "{:02x}".format(byte) + uint_str - print("};") diff --git a/examples/41_fused_multi_head_attention/CMakeLists.txt b/examples/41_fused_multi_head_attention/CMakeLists.txt new file mode 100644 index 0000000000..8ed6227010 --- /dev/null +++ b/examples/41_fused_multi_head_attention/CMakeLists.txt @@ -0,0 +1,56 @@ + +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_example_add_executable( + 41_fused_multi_head_attention_fixed_seqlen + fused_multihead_attention_fixed_seqlen.cu + ) + +cutlass_example_add_executable( + 41_fused_multi_head_attention_variable_seqlen + fused_multihead_attention_variable_seqlen.cu + ) + +cutlass_example_add_executable( + 41_fused_multi_head_attention_backward + fused_multi_head_attention_backward.cu + DISABLE_TESTS ON + ) + + +add_custom_target(41_fused_multi_head_attention +DEPENDS 41_fused_multi_head_attention_fixed_seqlen + 41_fused_multi_head_attention_variable_seqlen + 41_fused_multi_head_attention_backward +) + +add_test( + NAME ctest_examples_41_fmha_backward_python + COMMAND ${Python3_EXECUTABLE} ${CMAKE_CURRENT_LIST_DIR}/fmha_backward_test.py $ +) diff --git a/examples/41_fused_multi_head_attention/debug_utils.h b/examples/41_fused_multi_head_attention/debug_utils.h new file mode 100644 index 0000000000..efca4f132d --- /dev/null +++ b/examples/41_fused_multi_head_attention/debug_utils.h @@ -0,0 +1,234 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once +#include +#include +#include + +//////////////////////////////////////////////////////////////////////////////// +// Debugging functions +//////////////////////////////////////////////////////////////////////////////// +// Nans & inf detection +#define NANCHECK(frag) \ + { \ + for (size_t _i = 0; _i < frag.size(); ++_i) { \ + assert(std::isfinite(float(frag[_i]))); \ + assert(!std::isnan(float(frag[_i]))); \ + } \ + } + +// Print on the first thread of the first block +#if 1 +#define PRINT_WARP_ID 0 +#define PRINT_LANE_ID 0 +#define PRINT_B0_T0(msg, ...) \ + if (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && \ + threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \ + threadIdx.z == 0) { \ + printf(msg "\n", ##__VA_ARGS__); \ + } +#define PRINT_T0(msg, ...) \ + if (threadIdx.x == PRINT_LANE_ID && threadIdx.y == PRINT_WARP_ID && \ + threadIdx.z == 0) { \ + printf(msg "\n", ##__VA_ARGS__); \ + } +#define PRINT_TX_LX(msg, ...) \ + for (int bx = 0; bx < gridDim.x; ++bx) { \ + for (int by = 0; by < gridDim.y; ++by) { \ + for (int bz = 0; bz < gridDim.z; ++bz) { \ + for (int tx = 0; tx < blockDim.x; ++tx) { \ + for (int ty = 0; ty < blockDim.y; ++ty) { \ + for (int tz = 0; tz < blockDim.z; ++tz) { \ + __syncthreads(); \ + if (blockIdx.x == bx && blockIdx.y == by && blockIdx.z == bz && \ + threadIdx.x == tx && threadIdx.y == ty && \ + threadIdx.z == tz) { \ + printf( \ + "[%d,%d,%d][%d,%d,%d]" msg "\n", \ + bx, \ + by, \ + bz, \ + tx, \ + ty, \ + tz, \ + ##__VA_ARGS__); \ + } \ + } \ + } \ + } \ + } \ + } \ + } +#else +#define PRINT_B0_T0 +#define PRINT_TX_LX +#endif + +struct __string_view { + char const* data; + std::size_t size; +}; +#if __cplusplus >= 201402L +template +constexpr __string_view __get_type_name() { + char const* p = __PRETTY_FUNCTION__; + while (*p++ != '=') + ; + for (; *p == ' '; ++p) + ; + char const* p2 = p; + int count = 1; + for (;; ++p2) { + switch (*p2) { + case '[': + ++count; + break; + case ']': + --count; + if (!count) + return {p, std::size_t(p2 - p)}; + } + } + return {}; +} +#else +template +constexpr __string_view __get_type_name() { + return {"unsupported", 11}; +} +#endif + +// Print a given array +#define PRINT_ACCUM8_T0_L0_START(name, accum, start) \ + PRINT_B0_T0( \ + "%s[%d:%d] - {%f, %f, %f, %f, %f, %f, %f, %f}", \ + name, \ + int(start), \ + int(start + 8), \ + float(accum[start + 0]), \ + float(accum[start + 1]), \ + float(accum[start + 2]), \ + float(accum[start + 3]), \ + float(accum[start + 4]), \ + float(accum[start + 5]), \ + float(accum[start + 6]), \ + float(accum[start + 7])); +#define PRINT_ACCUM8_T0_L0(name, accum) PRINT_ACCUM8_T0_L0_START(name, accum, 0) +#define PRINT_FRAG_T0_L0(name, frag) \ + { \ + auto typeStr = __get_type_name(); \ + PRINT_B0_T0("printing %s (%s)", name, typeStr.data); \ + for (size_t _start = 0; _start < frag.size(); _start += 8) { \ + PRINT_ACCUM8_T0_L0_START(" ", frag, _start); \ + } \ + /*__syncthreads(); \ + NANCHECK(frag); */ \ + } +#define PRINT_ARRAY_T0_L0_INCR(name, array, length, incr) \ + { \ + PRINT_B0_T0("printing %s (len=%d)", name, int(length)); \ + for (int _start = 0; _start < length; _start += incr) { \ + PRINT_ACCUM8_T0_L0_START(" ", array, _start); \ + } \ + } +#define PRINT_ARRAY_T0_L0(name, array, length) \ + PRINT_ARRAY_T0_L0_INCR(name, array, length, 8) + +// Print a 4x4 matrix +#define PRINT_TENSOR4x4_T0_L0_START(name, ref, start_x, start_y) \ + 
PRINT_B0_T0( \ + "%s[%d:%d, %d:%d]:\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f\n %f, %f, %f, %f", \ + name, \ + int(start_x), \ + int(start_x + 4), \ + int(start_y), \ + int(start_y + 4), \ + float(ref.at({start_x + 0, start_y + 0})), \ + float(ref.at({start_x + 0, start_y + 1})), \ + float(ref.at({start_x + 0, start_y + 2})), \ + float(ref.at({start_x + 0, start_y + 3})), \ + float(ref.at({start_x + 1, start_y + 0})), \ + float(ref.at({start_x + 1, start_y + 1})), \ + float(ref.at({start_x + 1, start_y + 2})), \ + float(ref.at({start_x + 1, start_y + 3})), \ + float(ref.at({start_x + 2, start_y + 0})), \ + float(ref.at({start_x + 2, start_y + 1})), \ + float(ref.at({start_x + 2, start_y + 2})), \ + float(ref.at({start_x + 2, start_y + 3})), \ + float(ref.at({start_x + 3, start_y + 0})), \ + float(ref.at({start_x + 3, start_y + 1})), \ + float(ref.at({start_x + 3, start_y + 2})), \ + float(ref.at({start_x + 3, start_y + 3}))); +#define PRINT_TENSOR4x4_T0_L0(name, ref) \ + PRINT_TENSOR4x4_T0_L0_START(name, ref, 0, 0) + +#define PRINT_PROBLEM_SIZE(name, ps) \ + PRINT_B0_T0( \ + "%s.problem_size: {.m=%d, .n=%d, .k=%d}", \ + name, \ + int(ps.m()), \ + int(ps.n()), \ + int(ps.k())) + +template +CUTLASS_DEVICE void print_warp_accum( + AccumT accum, + LaneOffsetT lane_offset, + int32_t num_rows, + int32_t num_cols) { + bool is_main = blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0 && + threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0; + for (int row = 0; row < num_rows; ++row) { + for (int col = 0; col < num_cols; ++col) { + if (col % 32 == 0) { + if (is_main) { + printf("\nmat[%3d, %3d:%3d]", row, col, col + 32); + } + __syncthreads(); + } + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (row == accum_m && col == accum_n && + (blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0)) { + printf(" %6.1f", float(accum[idx])); + } + }, + [&](int accum_m) {}); + __syncthreads(); + } + if (is_main) { + printf("\n"); + } + } +} diff --git a/examples/41_fused_multi_head_attention/default_fmha_grouped.h b/examples/41_fused_multi_head_attention/default_fmha_grouped.h new file mode 100644 index 0000000000..54e537c9ca --- /dev/null +++ b/examples/41_fused_multi_head_attention/default_fmha_grouped.h @@ -0,0 +1,299 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level GEMM definitions combine threadblock-scoped matrix multiply-add with + the appropriate threadblock-scoped epilogue. + + Note, CUTLASS epilogues universally target row-major outputs. Column-major outputs are + accommodated by exchanging A and B operands and assuming transposed layouts. Partial + specializations here choose 'device::GemmTransposed' to implement this functionality. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/complex.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +#include "fmha_grouped.h" +#include "gemm_kernel_utils.h" +#include "gemm/custom_mma.h" +#include "gemm/find_default_mma.h" +#include "gemm/mma_from_smem.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + // The datatype of Q/K/V + typename scalar_t_, + // Architecture we are targeting (eg `cutlass::arch::Sm80`) + typename ArchTag_, + // If Q/K/V are correctly aligned in memory and we can run a fast kernel + bool isAligned_, + int kQueriesPerBlock, + int kKeysPerBlock, + int kMaxK = (int)cutlass::platform::numeric_limits::max(), + GroupScheduleMode GroupScheduleMode_ = GroupScheduleMode::kDeviceOnly + > +struct DefaultFMHAGrouped { + using scalar_t = scalar_t_; + using accum_t = float; + using output_t = scalar_t; + + // Accumulator between 2 iterations + // Using `accum_t` improves perf on f16 at the cost of + // numerical errors + using output_accum_t = accum_t; + + using ArchTag = ArchTag_; + static bool const kIsAligned = isAligned_; + static bool const kSingleValueIteration = kMaxK <= kKeysPerBlock; + static constexpr bool kIsHalf = cutlass::sizeof_bits::value == 16; + static int const kWarpSize = 32; + static int const kNumWarpsPerBlock = kQueriesPerBlock * kKeysPerBlock / (kWarpSize * kWarpSize); + + struct MM0 { + /* + In this first matmul, we compute a block of `Q @ K.T`. 
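+    (Informally, for a query tile Q_i and a key tile K_j this yields the score
+    block S = Q_i @ K_j^T; `mi`, `m_prime` and `s_prime` track the running
+    row-wise maximum and the running row-wise sum of exp(S - max), so the
+    softmax over all key blocks can be accumulated in a single, numerically
+    stable streaming pass.)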
+ While the calculation result is still hot in registers, we update + `mi`, `m_prime`, `s_prime` in shared-memory, and then store this value + into a shared-memory ("AccumulatorSharedStorage") that is used later as + operand A for the second matmul (see MM1) + */ + + using GemmType = gemm_kernel_utils::DefaultGemmType; + using OpClass = typename GemmType::OpClass; + + using ElementA = scalar_t; + using ElementB = scalar_t; + using ElementC = scalar_t; + using ElementAccumulator = accum_t; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + OpClass, + ArchTag, + ElementA, + ElementB, + ElementC, + ElementAccumulator + >; + + static int const kAlignmentA = + kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment; + static int const kAlignmentB = + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; + + using ThreadblockShape = cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + static int const kStages = DefaultConfig::kStages; + using Operator = typename GemmType::Operator; + + using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementAccumulator, + LayoutC, + OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + ArchTag::kMinComputeCapability >= 80 && kIsHalf + ? 4 + : DefaultConfig::kStages, + Operator + >::DefaultMma; + + using MmaCore = typename DefaultMma::MmaCore; + using IteratorA = typename DefaultMma::IteratorA; + using IteratorB = typename DefaultMma::IteratorB; + using DefaultThreadblockMma = typename DefaultMma::ThreadblockMma; + using Mma = typename cutlass::platform::conditional< + kSingleValueIteration, + typename MakeCustomMma::Mma, + DefaultThreadblockMma>::type; + using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator< + typename Mma::Operator::IteratorC, + ElementAccumulator, + kWarpSize>::Iterator; + + static_assert(MmaCore::WarpCount::kCount == kNumWarpsPerBlock, ""); + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< + typename Mma::Operator::IteratorC, + typename Mma::Operator, + scalar_t, + WarpShape, + ThreadblockShape>; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; + }; + + struct MM1 { + /* + Second matmul: perform `attn @ V` where `attn` is the attention (not + normalized) and stored in shared memory + */ + + using GemmType = typename MM0::GemmType; + using OpClass = typename GemmType::OpClass; + + using ElementA = scalar_t; + using ElementB = scalar_t; + using ElementC = output_accum_t; + using ElementAccumulator = accum_t; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::RowMajor; + using LayoutC = cutlass::layout::RowMajor; + + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + OpClass, + ArchTag, + ElementA, + ElementB, + ElementC, + ElementAccumulator + >; + + static int const kAlignmentA = DefaultConfig::kAlignmentA; + static int const kAlignmentB = + kIsAligned ? 
DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; + + using ThreadblockShape = typename MM0::ThreadblockShape; + using WarpShape = typename MM0::WarpShape; + using InstructionShape = typename MM0::InstructionShape; + + using EpilogueOutputOp = typename DefaultConfig::EpilogueOutputOp; + + static int const kStages = DefaultConfig::kStages; + using Operator = typename GemmType::Operator; + + using ThreadblockSwizzle = void; // Swizzling is unused + static bool const kSplitKSerial = false; + + using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementC, + LayoutC, + ElementAccumulator, + OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + ArchTag::kMinComputeCapability >= 80 && kIsHalf + ? 4 + : DefaultConfig::kStages, + kSplitKSerial, + Operator>; + + using WarpIteratorA = typename cutlass::gemm::threadblock:: + DefaultWarpIteratorAFromSharedMemory< + typename DefaultGemm::Mma::Policy::Operator::Shape, // WarpShape + typename DefaultGemm::Mma::Policy::Operator::InstructionShape, + typename DefaultGemm::Mma::Policy::Operator::IteratorA, + typename DefaultGemm::Mma::Policy>::WarpIterator; + + using DefaultMmaFromSmem = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + MM0::AccumulatorSharedStorage::Shape::kN, // kMaxK + WarpIteratorA, + false>; // kScaleOperandA + + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + static_assert(WarpCount::kCount == kNumWarpsPerBlock, ""); + + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_t>; + using OutputTileIteratorAccum = + typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_accum_t>; + }; + +/// Define the kernel in terms of the default kernel + using FMHAKernel = kernel::FMHAGrouped< + MM0, + MM1, + scalar_t, + accum_t, + output_t, + output_accum_t, + kSingleValueIteration, + GroupScheduleMode_ + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/41_fused_multi_head_attention/epilogue/epilogue_pipelined.h b/examples/41_fused_multi_head_attention/epilogue/epilogue_pipelined.h new file mode 100644 index 0000000000..e166af4de4 --- /dev/null +++ b/examples/41_fused_multi_head_attention/epilogue/epilogue_pipelined.h @@ -0,0 +1,624 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + File copied from "cutlass/epilogue/threadblock/epilogue.h" + then modified to: + (1) load 2 source fragments at the same time (pipelining) + (2) support reading from a different dtype + (3) pass the row id to the OutputOp if it takes it + (see MemoryEfficientAttentionNormalize) + Note that in general the fragment passed to the OutputOp could + span multiple rows but it does not happen with the configurations we have +*/ + +#pragma once + +#include + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/functional.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +template +struct ApplyEpilogueOp { + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum, + typename Op::FragmentOutput const& source) { + return output_op(accum, source); + } + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum) { + return output_op(accum); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator +template < + typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) + typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: + ///< gemm::warp::MmaTensorOp) + int PartitionsK, ///< Number of partitions of the K dimension + typename OutputTileIterator_, ///< Tile iterator writing output tensors + typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting + 
///< accumulators + typename WarpTileIterator_, ///< Warp-scoped tile iterator writing + ///< accumulators to SMEM + typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading + ///< from SMEM + typename OutputOp_, ///< Output operator + typename Padding_, ///< Padding added to SMEM allocation to avoid bank + ///< conflicts (concept: MatrixShape) + int FragmentsPerPartition = + 1, ///< Used to coarsten the epilogue granularity + int IterationsUnroll = ///< Used to reduce binary size when epilogue op is + ///< large + (!IsEpilogueFunctorHeavy::value), + typename OutputTileSourceIterator_ = + OutputTileIterator_ ///< Tile iterator reading tensors + > +class EpiloguePipelined : public EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_, + FragmentsPerPartition> { + public: + using Base = EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_, + FragmentsPerPartition>; + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + using OutputTileIterator = OutputTileIterator_; + using OutputTileSourceIterator = OutputTileSourceIterator_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using WarpTileIterator = WarpTileIterator_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp = OutputOp_; + using Padding = Padding_; + + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// The complete warp-level accumulator tile + using AccumulatorTile = typename Base::AccumulatorTile; + + /// Accumulator element + using ElementAccumulator = typename WarpTileIterator::Element; + + /// Output element + using ElementOutput = typename OutputTileIterator::Element; + using ElementSource = typename OutputTileSourceIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + /// Tensor reference to destination tensor + using TensorRef = typename OutputTileIterator::TensorRef; + + /// Tensor reference to sync tensor + using SyncTensorRef = + typename cutlass::TensorRef; + + /// Const tensor reference to source tensor + using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; + + /// Array type used to output + using OutputAccessType = Array< + typename OutputTileIterator::Element, + OutputTileIterator::kElementsPerAccess>; + using SourceAccessType = Array< + typename OutputTileSourceIterator::Element, + OutputTileSourceIterator::kElementsPerAccess>; + + /// Array type used by output functor + using AccumulatorAccessType = Array< + typename WarpTileIterator::Element, + OutputTileIterator::kElementsPerAccess>; + + /// Number of warps + using WarpCount = typename Base::WarpCount; + + static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 + ? 
Base::kFragmentsPerIteration + : kPartitionsK; + static int constexpr kSmemPointerOffset = + Base::SharedStorage::StorageShape::kCount / kSmemTiles; + + public: + static_assert( + OutputTileSourceIterator::Fragment::kElements == + OutputTileIterator::Fragment::kElements, + "Mismatch between input tile and output tile iterator (kElements)"); + static_assert( + OutputTileSourceIterator::kIterations == OutputTileIterator::kIterations, + "Mismatch between input tile and output tile iterator (kIterations)"); + static_assert( + SharedLoadIterator::Fragment::kElements == + OutputTileIterator::Fragment::kElements, + "Mismatch between shared load iterator and output tile iterator."); + + static_assert( + OutputTileIterator::kElementsPerAccess, + "OutputTileIterator::kElementsPerAccess must not be zero."); + + static_assert( + !(OutputTileIterator::Fragment::kElements % + OutputTileIterator::kElementsPerAccess), + "Divisibility"); + + private: + /// Loads fragment from shared memory aligned with output tensor + SharedLoadIterator shared_load_iterator_; + + public: + /// Constructor + CUTLASS_DEVICE + EpiloguePipelined( + typename Base::SharedStorage& shared_storage, ///< Shared storage object + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx ///< Id of thread within warp + ) + : Base(shared_storage, thread_idx, warp_idx, lane_idx), + shared_load_iterator_(shared_storage.reference(), thread_idx) {} + + /// Streams the result to global memory + CUTLASS_DEVICE + void operator()( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators, ///< Complete warp-level accumulator tile + OutputTileSourceIterator + source_iterator) { ///< Threadblock tile coordinate in GEMM (in units + ///< of threadblock tiles) + + if (!output_op.is_source_needed()) { + compute_source_not_needed_(output_op, destination_iterator, accumulators); + } else { + compute_source_needed_( + output_op, destination_iterator, accumulators, source_iterator); + } + } + CUTLASS_DEVICE + void operator()( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators) { ///< Complete warp-level accumulator tile + compute_source_not_needed_(output_op, destination_iterator, accumulators); + } + + private: + template + struct acc2smem_source_not_needed; + + template + struct acc2smem_source_not_needed> { + template + CUTLASS_DEVICE static void helper( + AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator& warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + typename AccumulatorFragmentIterator::Fragment accum_fragment; + + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + + warp_tile_iterator.store(accum_fragment); + if (p < Base::kFragmentsPerIteration - 1) { + warp_tile_iterator.add_pointer_offset(kSmemPointerOffset); + } + } + + if (Base::kFragmentsPerIteration > 1) { + warp_tile_iterator.add_pointer_offset( + kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); + } + } + + CUTLASS_DEVICE + static void push( + size_t pos, + AccumulatorFragmentIterator const& iterator_begin, + WarpTileIterator& warp_tile_iterator) { + int dummy[] = { + (pos == (Seq * 
Base::kFragmentsPerIteration)) && + (helper( + iterator_begin, warp_tile_iterator), + 0)...}; + + CUTLASS_UNUSED(dummy[0]); + } + }; + + static_assert( + kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, + "One of these must be exactly 1."); + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_not_needed_( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators ///< Complete warp-level accumulator tile + ) { + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + +#pragma unroll( \ + IterationsUnroll \ + ? OutputTileIterator::kIterations / Base::kFragmentsPerIteration \ + : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; + iter += Base::kFragmentsPerIteration) { + // + // Convert and store fragment + // + + __syncthreads(); + + acc2smem_source_not_needed>:: + push(iter, accum_fragment_iterator, this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < Base::kFragmentsPerIteration; ++p) { + typename SharedLoadIterator::Fragment + aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + if (p < Base::kFragmentsPerIteration - 1) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + } else if (kPartitionsK > 1) { + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments( + aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_pointer_offset( + (1 - kPartitionsK) * kSmemPointerOffset); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment; + + apply_output_operator_source_not_needed_( + destination_iterator.thread_start_row(), + output_fragment, + output_op, + aligned_accum_fragment[0]); + + // + // Store the final result + // + + destination_iterator.store(output_fragment); + ++destination_iterator; + } + + if (Base::kFragmentsPerIteration > 1) { + shared_load_iterator_.add_pointer_offset( + kSmemPointerOffset * (1 - Base::kFragmentsPerIteration)); + } + } + } + + template + struct acc2smem_source_needed; + + template + struct acc2smem_source_needed> { + template + CUTLASS_DEVICE static void helper( + AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator& warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + accum_fragment_iterator.load(accum_fragment); + warp_tile_iterator.store(accum_fragment); + } + + CUTLASS_DEVICE + static void push( + size_t pos, + AccumulatorFragmentIterator const& iterator_begin, + WarpTileIterator& warp_tile_iterator) { + int dummy[] = { + (pos == Seq) && + (helper(iterator_begin, warp_tile_iterator), 0)...}; + } + }; + + /// Streams the result to global memory + CUTLASS_DEVICE + void compute_source_needed_( + OutputOp const& output_op, ///< Output operator + OutputTileIterator + destination_iterator, ///< Tile iterator for destination + AccumulatorTile const& + accumulators, ///< Complete warp-level accumulator 
tile + OutputTileSourceIterator + source_iterator ///< Threadblock tile coordinate in GEMM (in units of + ///< threadblock tiles) + ) { + typename OutputTileSourceIterator::Fragment source_fragment[2]; + + source_fragment[0].clear(); + source_iterator.load(source_fragment[0]); + ++source_iterator; + source_fragment[1].clear(); + + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + + // + // Iterate over accumulator tile + // + +#pragma unroll(IterationsUnroll ? OutputTileIterator::kIterations : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + if (iter > 0) { + __syncthreads(); + } + // + // Load the source for next iteration (pipelining) + // + + if (iter + 1 < OutputTileIterator::kIterations) { + source_iterator.load(source_fragment[(iter + 1) % 2]); + } + ++source_iterator; + acc2smem_source_needed< + cutlass::make_index_sequence>:: + push(iter, accum_fragment_iterator, this->warp_tile_iterator_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + typename SharedLoadIterator::Fragment + aligned_accum_fragment[kPartitionsK]; + + shared_load_iterator_.load(aligned_accum_fragment[0]); + + // If the number of k-slices is > 1 - perform a reduction amongst the + // k-slices + if (kPartitionsK > 1) { + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator_.load(aligned_accum_fragment[i]); + aligned_accum_fragment[0] = add_fragments( + aligned_accum_fragment[0], aligned_accum_fragment[i]); + } + + shared_load_iterator_.add_pointer_offset( + (1 - kPartitionsK) * kSmemPointerOffset); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment; + + apply_output_operator_( + destination_iterator.thread_start_row(), + output_fragment, + output_op, + aligned_accum_fragment[0], + source_fragment[iter % 2]); + + // + // Store the final result + // + + destination_iterator.store(output_fragment); + ++destination_iterator; + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_( + int begin_row, + typename OutputTileIterator::Fragment& output_fragment, + OutputOp const& output_op, ///< Output operator + typename SharedLoadIterator::Fragment const& aligned_accum_fragment, + typename OutputTileSourceIterator::Fragment const& source_fragment) { + OutputAccessType* output_frag_ptr = + reinterpret_cast(&output_fragment); + + AccumulatorAccessType const* compute_frag_ptr = + reinterpret_cast(&aligned_accum_fragment); + + SourceAccessType const* source_frag_ptr = + reinterpret_cast(&source_fragment); + + int const kOutputOpIterations = OutputTileIterator::Fragment::kElements / + OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + // Call the output operator + output_frag_ptr[i] = ApplyEpilogueOp::apply( + output_op, + begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess), + compute_frag_ptr[i], + source_frag_ptr[i]); + } + } + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_source_not_needed_( + int begin_row, + typename OutputTileIterator::Fragment& output_fragment, + OutputOp const& output_op, ///< Output operator + typename SharedLoadIterator::Fragment const& aligned_accum_fragment) { + 
OutputAccessType* output_frag_ptr = + reinterpret_cast(&output_fragment); + + AccumulatorAccessType const* compute_frag_ptr = + reinterpret_cast(&aligned_accum_fragment); + + int const kOutputOpIterations = OutputTileIterator::Fragment::kElements / + OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + // Call the output operator + output_frag_ptr[i] = ApplyEpilogueOp::apply( + output_op, + begin_row + getRowOffset(i * OutputTileIterator::kElementsPerAccess), + compute_frag_ptr[i]); + } + } + + // This should be constexpr, but it's only supported on c++14 + static int CUTLASS_HOST_DEVICE getRowOffset(int i) { + using ThreadMap = typename OutputTileIterator::ThreadMap; + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + int frag_idx = ThreadMap::kElementsPerAccess * + (frag_row_idx * ThreadMap::Iterations::kColumn + column); + if (i < frag_idx + ThreadMap::kElementsPerAccess) { + return row_offset; + } + } + } + } + } + return -1; + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/41_fused_multi_head_attention/epilogue/epilogue_rescale_output.h b/examples/41_fused_multi_head_attention/epilogue/epilogue_rescale_output.h new file mode 100644 index 0000000000..6860ee9e4c --- /dev/null +++ b/examples/41_fused_multi_head_attention/epilogue/epilogue_rescale_output.h @@ -0,0 +1,254 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory + to match canonical tensor layouts in global memory. Epilogues support + conversion and reduction operations. + + This is a copy of cutlass/epilogue/threadblock/epilogue.h that can + handle "row_id" as a first argument, as uses it to get the corresponding + `m_prime` / `s_prime` to rescale the output. +*/ + +#pragma once + +#include + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/functional.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "epilogue_pipelined.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to an array of elements. +// output <- alpha * accumulator + beta * source +// with: +// alpha = 1 / s_prime (to normalize when isLast=True, 1 otherwise) +// beta = alpha / m_prime (renormalize the output when the max changes) +// source is the current output +template < + typename ElementOutput_, ///< Data type used to store tensors + typename ElementSource_, //< Data type for source (usually matches + //`ElementOutput`) + int Count, ///< Number of elements computed per operation. 
+ ///< Usually it is 128/sizeof_bits, + ///< but we use 64 or 32 sometimes when there are not enough data + ///< to store + typename ElementAccumulator_, ///< Accumulator data type + typename ElementCompute_, ///< Data type used to compute linear combination + bool isFirst, + bool isLast, + typename FragmentAlphaBeta_, + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest> +class MemoryEfficientAttentionNormalize { + public: + using ElementOutput = ElementOutput_; + using ElementSource = ElementSource_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + + using FragmentOutput = Array; + using FragmentSource = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; + using FragmentAlphaBeta = FragmentAlphaBeta_; + + static FloatRoundStyle const kRound = Round; + + private: + // + // Data members + // + + FragmentAlphaBeta const& s_prime_; + FragmentAlphaBeta const& m_prime_; + + public: + /// Constructs the function object, possibly loading from pointers in host + /// memory + CUTLASS_HOST_DEVICE + MemoryEfficientAttentionNormalize( + FragmentAlphaBeta const& s_prime, + FragmentAlphaBeta const& m_prime) + : s_prime_(s_prime), m_prime_(m_prime) {} + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return !isFirst; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) {} + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + int row, + FragmentAccumulator const& accumulator, + FragmentSource const& source) const { + assert(!isFirst); + + // Convert source to interal compute numeric type + NumericArrayConverter + source_converter; + NumericArrayConverter + accumulator_converter; + + // Convert to destination numeric type + NumericArrayConverter + destination_converter; + + ComputeFragment converted_source = source_converter(source); + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + // Perform binary operations + ComputeFragment intermediate; + + multiplies mul_add_source; + multiply_add mul_add_accumulator; + + ElementCompute alpha = isLast ? (1 / s_prime_[row]) : 1; + ElementCompute beta = alpha * m_prime_[row]; + + intermediate = mul_add_source(beta, converted_source); // X = beta * C + + intermediate = mul_add_accumulator( + alpha, converted_accumulator, intermediate); // D = alpha * Accum + X + + return destination_converter(intermediate); + } + + /// Computes linear scaling: D = alpha * accumulator + CUTLASS_HOST_DEVICE + FragmentOutput operator()(int row, FragmentAccumulator const& accumulator) + const { + assert(isFirst); + + // Convert source to interal compute numeric type + NumericArrayConverter + accumulator_converter; + + // Convert to destination numeric type + NumericArrayConverter + destination_converter; + + ComputeFragment converted_accumulator = accumulator_converter(accumulator); + + ComputeFragment intermediate; + multiplies mul_accumulator; + + ElementCompute alpha = isLast ? 
(1 / s_prime_[row]) : 1; + + intermediate = mul_accumulator( + alpha, converted_accumulator); // X = alpha * C + uniform + + return destination_converter(intermediate); + } +}; + +} // namespace thread + +namespace threadblock { +template < + typename EO, + typename ES, + int Count, + typename EA, + typename EC, + bool F, + bool L, + typename FAB, + FloatRoundStyle R> +struct ApplyEpilogueOp> { + using Op = thread:: + MemoryEfficientAttentionNormalize; + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum, + typename Op::FragmentSource const& source) { + return output_op(row_id, accum, source); + } + static CUTLASS_DEVICE typename Op::FragmentOutput apply( + Op const& output_op, + int row_id, + typename Op::FragmentAccumulator const& accum) { + return output_op(row_id, accum); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/41_fused_multi_head_attention/epilogue/epilogue_thread_apply_logsumexp.h b/examples/41_fused_multi_head_attention/epilogue/epilogue_thread_apply_logsumexp.h new file mode 100644 index 0000000000..bc2a28c0cf --- /dev/null +++ b/examples/41_fused_multi_head_attention/epilogue/epilogue_thread_apply_logsumexp.h @@ -0,0 +1,174 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing linear combination operations used by epilogues. 
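+
+  In this file the functor of interest, ApplyLogSumExp, computes
+  output <- exp(input - lse) elementwise; since exp(x_j - logsumexp(x)) equals
+  softmax(x)_j, this in effect recovers attention probabilities from raw
+  scores and a precomputed row-wise log-sum-exp.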
+*/
+
+#pragma once
+
+#include <cuda_fp16.h>
+
+#include "cutlass/array.h"
+#include "cutlass/cutlass.h"
+#include "cutlass/epilogue/thread/activation.h"
+#include "cutlass/functional.h"
+#include "cutlass/numeric_conversion.h"
+#include "cutlass/numeric_types.h"
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace cutlass {
+namespace epilogue {
+namespace thread {
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace detail {
+
+template <typename Element, int ElementsPerAccess>
+struct ArrayExponential {
+  CUTLASS_HOST_DEVICE
+  Array<Element, ElementsPerAccess> operator()(
+      Array<Element, ElementsPerAccess> const& input) const {
+    Array<Element, ElementsPerAccess> result;
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < ElementsPerAccess; ++i) {
+      result[i] = expf(input[i]);
+    }
+
+    return result;
+  }
+};
+
+template <int ElementsPerAccess>
+struct ArrayExponential<half_t, ElementsPerAccess> {
+  CUTLASS_DEVICE
+  Array<half_t, ElementsPerAccess> operator()(
+      Array<half_t, ElementsPerAccess> const& input) const {
+    Array<half_t, ElementsPerAccess> result;
+
+    int const kVectorCount = ElementsPerAccess / 2;
+
+    __half2 const* input_ptr =
+        reinterpret_cast<__half2 const*>(input.raw_data());
+    __half2* res_ptr = reinterpret_cast<__half2*>(result.raw_data());
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < kVectorCount; ++i) {
+      res_ptr[i] = h2exp(input_ptr[i]);
+    }
+
+    return result;
+  }
+};
+} // namespace detail
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Applies:
+///   output <- (input - lse).exp()
+template <
+    typename ElementOutput_, // output
+    typename ElementLSE_, // accumulator from LSE
+    typename ElementAccumulator_, // accumulator from matmul
+    typename ElementCompute_, // intermediate compute (and exp calculation)
+    int ElementsPerAccess>
+class ApplyLogSumExp {
+ public:
+  using ElementOutput = ElementOutput_;
+  using ElementAccumulator = ElementAccumulator_;
+  using ElementCompute = ElementCompute_;
+  using ElementLSE = ElementLSE_;
+
+  static int const kElementsPerAccess = ElementsPerAccess;
+  static int const kCount = kElementsPerAccess;
+  static const ScaleType::Kind kScale =
+      cutlass::epilogue::thread::ScaleType::NoBetaScaling;
+
+  using FragmentOutput = Array<ElementOutput, kCount>;
+  using FragmentAccumulator = Array<ElementAccumulator, kElementsPerAccess>;
+  using FragmentCompute = Array<ElementCompute, kElementsPerAccess>;
+  using FragmentLSE = Array<ElementLSE, kElementsPerAccess>;
+  using FragmentScaleBias = FragmentLSE; // Used by epilogue_smem_accumulator.h
+
+ public:
+  //
+  // Methods
+  //
+
+  CUTLASS_HOST_DEVICE
+  ApplyLogSumExp() {}
+
+  /// Returns true if source is needed
+  CUTLASS_HOST_DEVICE
+  bool is_source_needed() const {
+    return true;
+  }
+
+  /// Functionally required for serial reduction in the epilogue
+  CUTLASS_HOST_DEVICE
+  void set_k_partition(int k_partition, int k_partition_count) {}
+
+  CUTLASS_HOST_DEVICE
+  FragmentOutput operator()(
+      FragmentAccumulator const& AB,
+      FragmentLSE const& scale_unused,
+      // bias used as LSE
+      FragmentLSE const& bias) const {
+    FragmentCompute frag_AB = NumericArrayConverter<
+        ElementCompute,
+        ElementAccumulator,
+        kElementsPerAccess>()(AB);
+    FragmentCompute frag_lse_compute =
+        NumericArrayConverter<ElementCompute, ElementLSE, kElementsPerAccess>()(
+            bias);
+    FragmentCompute frag_compute;
+
+    minus<FragmentCompute> minus_lse;
+    detail::ArrayExponential<ElementCompute, kElementsPerAccess> apply_exp;
+    frag_compute = minus_lse(frag_AB, frag_lse_compute);
+    frag_compute = apply_exp(frag_compute);
+
+    return NumericArrayConverter<
+        ElementOutput,
+        ElementCompute,
+        kElementsPerAccess>()(frag_compute);
+  }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace thread
+} // namespace epilogue
+} // namespace cutlass
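For intuition, `ApplyLogSumExp` above computes `output = exp(accum - lse)` elementwise: given the raw (already scaled) Q·Kᵀ accumulator and the row-wise log-sum-exp stored by the forward pass, it reconstructs the softmax probabilities without rerunning a full softmax. A minimal PyTorch sketch of the same identity; the tensor names are illustrative only and not part of this header's API:

```python
import torch

scores = torch.randn(4, 8)                 # raw (scaled) Q @ K^T rows
lse = torch.logsumexp(scores, dim=-1)      # saved per query row by the forward pass

probs = torch.exp(scores - lse[:, None])   # what the epilogue applies elementwise

assert torch.allclose(probs, scores.softmax(dim=-1), atol=1e-6)
```

This identity is what lets the backward kernels recompute attention probabilities block by block from Q, K and the stored LSE instead of keeping the full attention matrix around.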
+///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/41_fused_multi_head_attention/fmha_backward_test.py b/examples/41_fused_multi_head_attention/fmha_backward_test.py new file mode 100644 index 0000000000..cdea9ded4d --- /dev/null +++ b/examples/41_fused_multi_head_attention/fmha_backward_test.py @@ -0,0 +1,232 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +import argparse +import torch +import sys +import os +from piped_subprocess import PipedSubprocess, TORCH_DTYPE_NAME +import math + + +parser = argparse.ArgumentParser() +parser.add_argument("example_exe", type=str, help="Path to the 41_fused_multi_head_attention_backward executable") +args = parser.parse_args() + +torch.manual_seed(0) +dtype = torch.float16 +B, Mq, Mkv, H, K, Kv = 2, 1024, 1024, 5, 128, 128 +causal = True +repeat_count = 100 + +ATOL = { + torch.float: 5e-4, + torch.half: 9.5e-2, + torch.bfloat16: 7e-1, +}[dtype] + +RTOL = { + torch.float: 1e-4, + torch.half: 2e-2, + torch.bfloat16: 1e-1, +}[dtype] + + +assert not (causal and Mq < Mkv), "causal only supports seqlenK <= seqlenQ" + +fmha_bw_binary = args.example_exe +if not os.path.isfile(fmha_bw_binary): + print(f"""No such file: `{fmha_bw_binary}`\nDid you forget to run "make 41_fused_multi_head_attention"?""") + sys.exit(1) + +def create_lower_triangular_mask(): + return torch.triu(torch.full( # type: ignore + [1, Mq, Mkv], + dtype=dtype, + fill_value=float("-inf"), + ), diagonal=1) + +def ref_mha_bmk(q, k, v, mask): + # Multi-head attention with inputs/outputs in BMK format + q = q.float() + k = k.float() + v = v.float() + + q = q * (1 / q.shape[-1] ** 0.5) + attn = q @ k.transpose(-2, -1) + if mask is not None: + attn += mask + attn_max = attn.max(-1, True).values + attn_norm = (attn - attn_max).exp().sum(-1, True) + attn = attn.softmax(-1) + lse = attn_max + attn_norm.log() + lse = lse.squeeze(2) + return attn @ v, lse + + +def bmhk2bmk(t): + return t.permute((0, 2, 1, 3)).reshape( + [t.shape[0] * t.shape[2], t.shape[1], t.shape[3]] + ) + +def ref_mha_bmhk(q, k, v, mask): + # Multi-head attention with inputs/outputs in BMHK format + assert q.ndim == 4 + + out, lse = ref_mha_bmk(bmhk2bmk(q), bmhk2bmk(k), bmhk2bmk(v), mask=mask) + out = out.reshape([q.shape[0], q.shape[2], q.shape[1], v.shape[3]]) + return out.permute((0, 2, 1, 3)), lse.reshape([q.shape[0], q.shape[2], q.shape[1]]) + +def ref_mha_bw_bmhk(q, k, v, mask, lse, out, grad_out, delta): + lse = lse[:, :, :q.shape[1]] #BMH, unpad Q dimension + delta = delta.reshape([-1, delta.shape[-1], 1]) + + # bmhk -> bmk + q, k, v, out, grad_out = [bmhk2bmk(x).float() for x in (q, k, v, out, grad_out)] + + attn_T = k @ q.transpose(-2, -1) + if mask is not None: + attn_T += mask.transpose(-2, -1) + attn_T = attn_T * (1 / q.shape[-1] ** 0.5) + attn_T = attn_T - lse.reshape([-1, 1, lse.shape[-1]]) + attn_T = attn_T.exp() + + grad_v = attn_T @ grad_out + + dov = grad_out @ v.transpose(-2, -1) + tmp = (dov - delta) * attn_T.transpose(-2, -1) + tmp = tmp / (q.shape[-1] ** 0.5) + + grad_q = tmp @ k + grad_k = tmp.transpose(-2, -1) @ q + + return [x.reshape([B, H, x.shape[1], x.shape[-1]]).permute([0, 2, 1, 3]) for x in [grad_q, grad_k, grad_v]] + + +print("initializing tensors...") +query = torch.randn([B, Mq, H, K], dtype=dtype) +key = 3 * torch.randn([B, Mkv, H, K], dtype=dtype) +value = 3 * torch.randn([B, Mkv, H, Kv], dtype=dtype) +mask = create_lower_triangular_mask() if causal else None + +# let PyTorch compute gradients +query.requires_grad_(True) +key.requires_grad_(True) +value.requires_grad_(True) + +print("computing fw...") +out, lse = ref_mha_bmhk(query, key, value, mask=mask) +out = out.to(dtype).contiguous() +grad_out = 3 * torch.randn([B, Mq, H, Kv], dtype=dtype) + +print("computing bw with autograd...") +out.backward(grad_out) +scale = (1 / query.shape[-1] ** 
0.5) + + +# Additional data needed by the kernel +delta = (grad_out.float() * out.float()).sum(-1).transpose(-2, -1).contiguous() +pad_amount = (32 - (lse.shape[2] % 32)) % 32 +lse = torch.nn.functional.pad(lse, [0, pad_amount], value=math.inf) + +print("computing bw with reference implem...") +gQr, gKr, gVr = ref_mha_bw_bmhk(query, key, value, mask, lse, out, grad_out, delta) + +with PipedSubprocess(fmha_bw_binary) as bw_kernel: + # Send kernel arguments + bw_kernel.write( + TORCH_DTYPE_NAME[query.dtype], + "scale", scale, + "head_dim", K, + "head_dim_value", Kv, + "num_queries", Mq, + "num_keys", Mkv, + "num_heads", H, + "custom_mask_type", (1 if causal else 0), + "num_batches", B, + "repeat_count", repeat_count, + "num_splits_key", (Mkv // 128), + ) + bw_kernel.writeTensor(query, "query", ["q_strideB", "q_strideM", "q_strideH"]) + bw_kernel.writeTensor(key, "key", ["k_strideB", "k_strideM", "k_strideH"]) + bw_kernel.writeTensor(value, "value", ["v_strideB", "v_strideM", "v_strideH"]) + bw_kernel.writeTensor(lse, "logsumexp", ["lse_strideB", "lse_strideH"]) + bw_kernel.writeTensor(out, "output", ["o_strideB", "o_strideM", "o_strideH"]) + bw_kernel.writeTensor(grad_out, "grad_output", ["gO_strideB", "gO_strideM", "gO_strideH"]) + bw_kernel.writeTensor(delta, "delta", ["delta_strideB", "delta_strideH"]) + + if bw_kernel.read() != "OK": + print("Got unexpected output") + print(bw_kernel.subp.communicate()[0]) + sys.exit(0) + + # Read kernel output + gQ = bw_kernel.readTensor("grad_query", ["gQ_strideB", "gQ_strideM", "gQ_strideH"], query.shape).float() + gK = bw_kernel.readTensor("grad_key", ["gK_strideB", "gK_strideM", "gK_strideH"], key.shape).float() + gV = bw_kernel.readTensor("grad_value", ["gV_strideB", "gV_strideM", "gV_strideH"], value.shape).float() + runtime_ms = float(bw_kernel.readNamed("runtime_ms")) + +float_ops = B * H * sum([ + # att = Q @ K.transpose + Mq * Mkv * K * 2, + # att @ dO + Mkv * Mq * Kv * 2, + # dov = dO @ V + Mq * Kv * Mkv * 2, + # dov @ K + Mq * K * Mkv * 2, + # dov @ Q + Mq * K * Mkv * 2, +]) +if causal: + float_ops //= 2 + +print(f""" +Fused multi-head attention - backward + batch_size={B} + num_queries={Mq} + num_keys={Mkv} + num_heads={H} + head_dim={K} + head_dim_value={Kv} + + Correctness: + grad_query: {"PASS" if torch.allclose(gQ, gQr, rtol=RTOL, atol=ATOL) else "FAIL"} (delta: {(gQ - gQr).abs().max()}) + grad_key: {"PASS" if torch.allclose(gK, gKr, rtol=RTOL, atol=ATOL) else "FAIL"} (delta: {(gK - gKr).abs().max()}) + grad_value: {"PASS" if torch.allclose(gV, gVr, rtol=RTOL, atol=ATOL) else "FAIL"} (delta: {(gV - gVr).abs().max()}) + (atol={ATOL} / rtol={RTOL}) + Runtime: {runtime_ms}ms ({(float_ops / (1024 ** 4)) / (runtime_ms / 1000):.4f} TFlops) +""") + +assert torch.allclose(query.grad.float(), gQr, rtol=RTOL, atol=ATOL), "Reference implementation does not match PyTorch autograd!" +assert torch.allclose(key.grad.float(), gKr, rtol=RTOL, atol=ATOL), "Reference implementation does not match PyTorch autograd!" +assert torch.allclose(value.grad.float(), gVr, rtol=RTOL, atol=ATOL), "Reference implementation does not match PyTorch autograd!" 
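The reference gradients computed in `ref_mha_bw_bmhk` above, together with the precomputed `delta = sum(grad_out * out, dim=-1)`, follow the standard closed-form backward of softmax attention. The self-contained PyTorch sketch below checks those formulas against autograd on a tiny single-head example; the shapes and variable names are illustrative and are not tied to the executable's I/O protocol:

```python
import torch

torch.manual_seed(0)
M, N, D = 16, 24, 8                    # queries, keys, head dim (illustrative)
scale = D ** -0.5
q = torch.randn(M, D, dtype=torch.float64, requires_grad=True)
k = torch.randn(N, D, dtype=torch.float64, requires_grad=True)
v = torch.randn(N, D, dtype=torch.float64, requires_grad=True)
grad_out = torch.randn(M, D, dtype=torch.float64)

# Forward pass and autograd reference
attn = ((q @ k.T) * scale).softmax(dim=-1)
out = attn @ v
out.backward(grad_out)

# Closed-form backward, as in ref_mha_bw_bmhk (modulo blocking/recomputation)
grad_v = attn.T @ grad_out
d_attn = grad_out @ v.T
delta = (grad_out * out).sum(-1, keepdim=True)   # the same "delta" fed to the kernel
d_scores = attn * (d_attn - delta)               # softmax Jacobian collapsed per row
grad_q = d_scores @ k * scale
grad_k = d_scores.T @ q * scale

for mine, ref in [(grad_q, q.grad), (grad_k, k.grad), (grad_v, v.grad)]:
    assert torch.allclose(mine, ref, atol=1e-10)
```

The `delta` trick is what allows the kernel to avoid materializing the full attention matrix: the softmax Jacobian contribution reduces to a single scalar per query row.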
diff --git a/examples/41_fused_multi_head_attention/fmha_grouped.h b/examples/41_fused_multi_head_attention/fmha_grouped.h new file mode 100644 index 0000000000..5a2f928ad8 --- /dev/null +++ b/examples/41_fused_multi_head_attention/fmha_grouped.h @@ -0,0 +1,1023 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Grouped FMHA kernel +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/complex.h" +#include "cutlass/semaphore.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/trace.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" + +#include "fmha_grouped_problem_visitor.h" +#include "gemm_kernel_utils.h" +#include "gemm/mma_accum_lambda_iterator.h" +#include "epilogue/epilogue_rescale_output.h" + + +namespace { + static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) { + // source: https://stackoverflow.com/a/51549250 + return (value >= 0) + ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) + : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); +} +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename MM0_, ///! Structure for computing P = Q @ K + typename MM1_, ///! Structure for computing O = P @ V + typename scalar_t_, + typename accum_t_, + typename output_t_, + typename output_accum_t_, + bool kKeepOutputInRF, ///! 
Whether the intermediate output from MM0_ should be kept in the register file + GroupScheduleMode GroupScheduleMode_ ///! Type of scheduling to perform +> +struct FMHAGrouped { +public: + using MM0 = MM0_; + using MM1 = MM1_; + + using scalar_t = scalar_t_; + using accum_t = accum_t_; + using output_t = output_t_; + using output_accum_t = output_accum_t_; + + static GroupScheduleMode const kGroupScheduleMode = GroupScheduleMode_; + + static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF && + !cutlass::platform::is_same::value; + + // Parameters to satisfy BaseGrouped + using ElementA = scalar_t; + using ElementB = scalar_t; + using ElementC = accum_t; + using LayoutA = typename MM0::LayoutA; + using LayoutB = typename MM0::ElementB; + using LayoutC = typename MM1::ElementC; + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + static int const kAlignmentA = MM0::kAlignmentA; + static int const kAlignmentB = MM0::kAlignmentB; + static int const kAlignmentC = 1; + using Mma = typename MM1::Mma; + using EpilogueOutputOp = typename MM1::EpilogueOutputOp; + using ThreadblockSwizzle = void; + using Operator = typename MM1::Operator; + using WarpShape = typename MM1::WarpShape; + using InstructionShape = typename MM1::InstructionShape; + + using ElementQ = scalar_t; + using ElementK = scalar_t; + using ElementP = accum_t; + using ElementV = scalar_t; + using ElementO = output_t; + using ElementOAccum = output_accum_t; + using ElementAccumulator = accum_t; + + using LayoutQ = typename MM0::LayoutA; + using LayoutK = typename MM0::LayoutB; + using LayoutP = typename MM0::LayoutC; + using LayoutV = typename MM1::LayoutB; + using LayoutO = typename MM1::LayoutC; + + static bool const kPreloadV = (MM1::Mma::ArchTag::kMinComputeCapability >= 80 && + cutlass::sizeof_bits::value == 16); + + static int const kAlignmentQ = MM0::kAlignmentA; + static int const kAlignmentK = MM0::kAlignmentB; + static int const kAlignmentV = 1; + + using ThreadblockShape = typename MM0::ThreadblockShape; + + static int const kQueriesPerBlock = ThreadblockShape::kM; + static int const kKeysPerBlock = ThreadblockShape::kN; + + static constexpr bool kSupportsDropout = false; + static constexpr bool kSupportsBias = false; + + /// Warp count (concept: GemmShape) + using WarpCount = typename MM1::WarpCount; + static int const kThreadsPerWarp = 32; + static int const kThreadCount = kThreadsPerWarp * WarpCount::kCount; + + static constexpr int kNumWarpsPerBlock = + kQueriesPerBlock * kKeysPerBlock / (kThreadsPerWarp * kThreadsPerWarp); + + using ProblemVisitor = FMHAGroupedProblemVisitor< + ThreadblockShape, + kGroupScheduleMode, + kThreadCount, + kThreadCount>; + + // + // Structures + // + + /// Argument structure + struct Arguments { + + // + // Data members + // + + GemmCoord *problem_sizes0{nullptr}; + GemmCoord *problem_sizes1{nullptr}; + + int problem_count{0}; + int threadblock_count{0}; + + ElementQ ** ptr_Q{nullptr}; + ElementK ** ptr_K{nullptr}; + ElementP ** ptr_P{nullptr}; + ElementV ** ptr_V{nullptr}; + ElementO ** ptr_O{nullptr}; + ElementOAccum ** ptr_O_accum{nullptr}; + + typename LayoutQ::Stride::LongIndex *ldq{nullptr}; + typename LayoutK::Stride::LongIndex *ldk{nullptr}; + typename LayoutP::Stride::LongIndex *ldv{nullptr}; + typename LayoutO::Stride::LongIndex *ldo{nullptr}; + + // Whether causal masking is to be performed + bool causal{false}; + + // Scale + ElementAccumulator scale{0}; + + // Only used by 
device-level operator + GemmCoord *host_problem_sizes{nullptr}; + + // + // Methods + // + + /// Default ctor + Arguments() = default; + + /// Ctor + CUTLASS_HOST_DEVICE + Arguments( + GemmCoord *problem_sizes0, + GemmCoord *problem_sizes1, + int problem_count, + int threadblock_count, + ElementQ ** ptr_Q, + ElementK ** ptr_K, + ElementP ** ptr_P, + ElementV ** ptr_V, + ElementO ** ptr_O, + ElementOAccum ** ptr_O_accum, + typename LayoutQ::Stride::LongIndex *ldq, + typename LayoutK::Stride::LongIndex *ldk, + typename LayoutP::Stride::LongIndex *ldp, + typename LayoutV::Stride::LongIndex *ldv, + typename LayoutO::Stride::LongIndex *ldo, + bool causal, + ElementAccumulator scale, + GemmCoord *host_problem_sizes=nullptr + ): + problem_sizes0(problem_sizes0), + problem_sizes1(problem_sizes1), + problem_count(problem_count), + threadblock_count(threadblock_count), + ptr_Q(ptr_Q), + ptr_K(ptr_K), + ptr_P(ptr_P), + ptr_V(ptr_V), + ptr_O(ptr_O), + ptr_O_accum(kNeedsOutputAccumulatorBuffer ? ptr_O_accum : (accum_t**)ptr_O), + ldq(ldq), + ldk(ldk), + ldv(ldv), + ldo(ldo), + causal(causal), + scale(scale), + host_problem_sizes(host_problem_sizes) + { + + } + + bool __host__ check_supported() { + CHECK_ALIGNED_PTR(ptr_Q, kAlignmentQ); + CHECK_ALIGNED_PTR(ptr_K, kAlignmentK); + CHECK_ALIGNED_PTR(ptr_V, kAlignmentV); + XFORMERS_CHECK(ldq % kAlignmentQ == 0, "query is not correctly aligned"); + XFORMERS_CHECK(ldk % kAlignmentK == 0, "key is not correctly aligned"); + XFORMERS_CHECK(ldv % kAlignmentV == 0, "value is not correctly aligned"); + return true; + } + }; + + // + // Structure for precomputing values in host memory and passing to kernels + // + + /// Parameters structure + struct Params { + + typename ProblemVisitor::Params problem_visitor; + int threadblock_count; + + ElementQ ** ptr_Q; + ElementK ** ptr_K; + ElementP ** ptr_P; + ElementV ** ptr_V; + ElementO ** ptr_O; + ElementOAccum ** ptr_O_accum; + + typename LayoutQ::Stride::LongIndex *ldq; + typename LayoutK::Stride::LongIndex *ldk; + typename LayoutP::Stride::LongIndex *ldv; + typename LayoutO::Stride::LongIndex *ldo; + + ElementAccumulator scale; + bool causal; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + ptr_Q(nullptr), + ptr_K(nullptr), + ptr_P(nullptr), + ptr_V(nullptr), + ptr_O(nullptr), + ptr_O_accum(nullptr), + ldq(nullptr), + ldk(nullptr), + ldv(nullptr), + ldo(nullptr), + causal(false), + scale(0) + { } + + CUTLASS_HOST_DEVICE + Params(Arguments const &args, + void *workspace = nullptr, + int tile_count = 0): + problem_visitor(args.problem_sizes0, args.problem_sizes1, args.problem_count, workspace, tile_count), + threadblock_count(args.threadblock_count), + ptr_Q(args.ptr_Q), + ptr_K(args.ptr_K), + ptr_P(args.ptr_P), + ptr_V(args.ptr_V), + ptr_O(args.ptr_O), + ptr_O_accum(kNeedsOutputAccumulatorBuffer ? args.ptr_O_accum : (accum_t**)args.ptr_O), + ldq(args.ldq), + ldk(args.ldk), + ldv(args.ldv), + ldo(args.ldo), + causal(args.causal), + scale(args.scale) + { + + } + + CUTLASS_HOST_DEVICE + void update( + Arguments const &args, + void *workspace = nullptr, + int tile_count = 0) { + + problem_visitor = typename ProblemVisitor::Params(args.problem_sizes0, + args.problem_sizes1, + args.problem_count, + workspace, tile_count); + threadblock_count = args.threadblock_count; + ptr_Q = args.ptr_Q; + ptr_K = args.ptr_K; + ptr_P = args.ptr_P; + ptr_V = args.ptr_V; + ptr_O = args.ptr_O; + ptr_O_accum = kNeedsOutputAccumulatorBuffer ? 
args.ptr_O_accum : (accum_t**)args.ptr_O; + ldq = args.ldq; + ldk = args.ldk; + ldv = args.ldv; + ldo = args.ldo; + causal = args.causal; + scale = args.scale; + } + }; + + // Shared storage - depends on kernel params + struct ScalingCoefs { + cutlass::Array m_prime; + cutlass::Array s_prime; + cutlass::Array mi; + cutlass::Array out_rescale; + cutlass::Array + addition_storage; + }; + + struct SharedStorageEpilogueAtEnd : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + typename MM0::AccumulatorSharedStorage si; + typename MM1::Mma::SharedStorage mm1; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& + epilogue_shared_storage() { + return epilogue; + } + + // ProblemVisitor shared storage can't be overlapped with others + typename ProblemVisitor::SharedStorage problem_visitor; + }; + + struct SharedStorageEpilogueInLoop : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + typename MM0::AccumulatorSharedStorage si; + typename MM1::Mma::SharedStorage mm1; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& + epilogue_shared_storage() { + return after_mm0.epilogue; + } + + // ProblemVisitor shared storage can't be overlapped with others + typename ProblemVisitor::SharedStorage problem_visitor; + }; + + using SharedStorage = typename cutlass::platform::conditional< + kKeepOutputInRF, + SharedStorageEpilogueAtEnd, + SharedStorageEpilogueInLoop>::type; + +private: + + // Parameters to be used by an individual tile + struct TileParams { + + CUTLASS_HOST_DEVICE + static int query_start(int threadblock_idx) { + return threadblock_idx * kQueriesPerBlock; + } + + // Returns whether this threadblock computes within the number of queries, + // which is determined by the M dimension of problem 0 + CUTLASS_HOST_DEVICE + static bool can_compute(int threadblock_idx, const GemmCoord& problem_size0) { + return query_start(threadblock_idx) < problem_size0.m(); + } + + CUTLASS_HOST_DEVICE + static int num_queries(int threadblock_idx, const GemmCoord& problem_size0) { + return problem_size0.m() - query_start(threadblock_idx); + } + + CUTLASS_HOST_DEVICE + static int num_keys(int threadblock_idx, const GemmCoord& problem_size0, bool causal) { + int nk = problem_size0.n(); + if (causal) { + nk = cutlass::fast_min(int32_t(query_start(threadblock_idx) + kQueriesPerBlock), nk); + } + return nk; + } + + }; + +public: + + // + // Methods + // + + CUTLASS_DEVICE + FMHAGrouped() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement(cutlass::gemm::GemmCoord const & problem_size) { + return Status::kSuccess; + } + + static Status can_implement(Arguments const &args) { + return Status::kSuccess; + } + + static CUTLASS_DEVICE int16_t thread_id() { + return threadIdx.x; + } + + static CUTLASS_DEVICE int8_t warp_id() { + return threadIdx.x / kThreadsPerWarp; + } + + static CUTLASS_DEVICE int8_t lane_id() { + return threadIdx.x % kThreadsPerWarp; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + auto& m_prime = shared_storage.m_prime; + auto& s_prime = shared_storage.s_prime; + 
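+      // Streaming-softmax state carried across key blocks (kept in shared memory):
+      //   mi          - per-row running maximum of the scaled attention logits
+      //   m_prime     - the maximum the current partial sums were computed with
+      //   s_prime     - per-row running softmax denominator, sum of exp(logit - mi)
+      //   out_rescale - exp(m_prime - mi) (evaluated with exp2f in the log2(e)-scaled
+      //                 domain), used to rescale previously accumulated output when a
+      //                 larger row maximum is encountered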
[[maybe_unused]] auto& si = shared_storage.after_mm0.si; + auto& mi = shared_storage.mi; + auto& out_rescale = shared_storage.out_rescale; + + ProblemVisitor problem_visitor( + params.problem_visitor, + shared_storage.problem_visitor, + blockIdx.x); + + // Outer 'persistent' loop to iterate over tiles + while (problem_visitor.next_tile()) { + + GemmCoord problem_size0 = problem_visitor.problem_size0(); + GemmCoord problem_size1 = problem_visitor.problem_size1(); + const int32_t threadblock_idx = int32_t(problem_visitor.threadblock_idx()); + + if (!TileParams::can_compute(threadblock_idx, problem_size0)) { + problem_visitor.advance(gridDim.x); + continue; + } + + const int32_t problem_idx = problem_visitor.problem_index(); + + if (thread_id() < kQueriesPerBlock) { + s_prime[thread_id()] = ElementAccumulator(0); + out_rescale[thread_id()] = accum_t(1.0); + m_prime[thread_id()] = + -cutlass::platform::numeric_limits::infinity(); + mi[thread_id()] = -cutlass::platform::numeric_limits::infinity(); + } + + ElementO *ptr_O = params.ptr_O[problem_idx] + TileParams::query_start(threadblock_idx) * params.ldo[problem_idx]; + ElementOAccum *ptr_O_accum = params.ptr_O_accum[problem_idx] + TileParams::query_start(threadblock_idx) * params.ldo[problem_idx]; + const int num_queries = TileParams::num_queries(threadblock_idx, problem_size0); + + auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator { + using OutputTileIterator = typename MM1::OutputTileIterator; + return OutputTileIterator( + typename OutputTileIterator::Params{(int32_t)params.ldo[problem_idx]}, + ptr_O, + typename OutputTileIterator::TensorCoord{ + num_queries, problem_size1.n()}, + thread_id(), + {0, col}); + }; + + auto createOutputAccumIter = [&](int col) -> + typename MM1::OutputTileIteratorAccum { + using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum; + return OutputTileIteratorAccum( + typename OutputTileIteratorAccum::Params{(int32_t)params.ldo[problem_idx]}, + ptr_O_accum, + typename OutputTileIteratorAccum::TensorCoord{ + num_queries, problem_size1.n()}, + thread_id(), + {0, col}); + }; + + typename MM1::Mma::FragmentC accum_o; + accum_o.clear(); + + const int num_keys = TileParams::num_keys(threadblock_idx, problem_size0, params.causal); + + for (int32_t iter_key_start = 0; iter_key_start < num_keys; + iter_key_start += kKeysPerBlock) { + int32_t problem_size_0_m = + cutlass::fast_min((int32_t)kQueriesPerBlock, num_queries); + int32_t problem_size_0_n = cutlass::fast_min( + (int32_t)kKeysPerBlock, num_keys - iter_key_start); + int32_t const& problem_size_0_k = problem_size0.k(); + int32_t const& problem_size_1_n = problem_size1.n(); + int32_t const& problem_size_1_k = problem_size_0_n; + + auto prologueV = [&](int blockN) { + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{typename MM1::LayoutB(params.ldv[problem_idx])}, + params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx], + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + + MM1::Mma::prologue( + shared_storage.after_mm0.mm1, + iterator_V, + thread_id(), + problem_size_1_k); + }; + + __syncthreads(); // Need to have shared memory initialized, and `m_prime` + // updated from end of prev iter + + // + // MATMUL: Q.K_t + // + // Computes the block-matrix product of: + // (a) query[query_start:query_end, :] + // with + // (b) key[iter_key_start:iter_key_start + kKeysPerBlock] + // and stores that into `shared_storage.si` + // + + ElementQ 
*ptr_Q = params.ptr_Q[problem_idx] + TileParams::query_start(threadblock_idx) * params.ldq[problem_idx]; + + // Construct iterators to A and B operands + typename MM0::IteratorA iterator_A( + typename MM0::IteratorA::Params( + typename MM0::MmaCore::LayoutA(params.ldq[problem_idx])), + ptr_Q, + {problem_size_0_m, problem_size_0_k}, + thread_id(), + {0, 0}); + + typename MM0::IteratorB iterator_B( + typename MM0::IteratorB::Params( + typename MM0::MmaCore::LayoutB(params.ldk[problem_idx])), + params.ptr_K[problem_idx] + iter_key_start * params.ldk[problem_idx], + {problem_size_0_k, problem_size_0_n}, + thread_id(), + {0, 0}); + + // Construct thread-scoped matrix multiply + typename MM0::Mma mma( + shared_storage.mm0, thread_id(), warp_id(), lane_id()); + + typename MM0::Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = + (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + __syncthreads(); + + if (kPreloadV) { + prologueV(0); + } else { + MM1::Mma::drain_cp_asyncs(); + } + + typename MM0::Mma::Operator::IteratorC::TensorCoord + iteratorC_tile_offset = { + (warp_id() % MM0::Mma::WarpCount::kM), + (warp_id() / MM0::Mma::WarpCount::kM) + }; + + // Mask out last if causal + if (params.causal && num_keys - iter_key_start <= kKeysPerBlock) { + auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset( + lane_id(), warp_id(), iteratorC_tile_offset); + int32_t last_col; + MM0::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + last_col = TileParams::query_start(threadblock_idx) + accum_m - iter_key_start; + }, + [&](int accum_m, int accum_n, int idx) { + if (accum_n > last_col) { + accum[idx] = + -cutlass::platform::numeric_limits::infinity(); + } + }, + [&](int accum_m) {}); + } + // DISPATCH_BOOL(iter_key_start == 0, kIsFirst, ([&] { + // DISPATCH_BOOL( + // num_keys - iter_key_start >= kKeysPerBlock, + // kFullColumns, + // ([&] { + // // Update `mi` from accum stored in registers + // // Also does accum[i] <- exp(accum[i] - mi) + // iterative_softmax< + // typename MM0::Mma::Operator::IteratorC, + // kFullColumns, + // kIsFirst>( + // accum_o, + // accum, + // mi, + // m_prime, + // s_prime, + // lane_id(), + // thread_id(), + // warp_id(), + // num_keys - iter_key_start, + // iteratorC_tile_offset, + // kSupportsBias ? 1.0f : params.scale); + // })); + // })); + + // Update `mi` from accum stored in registers + // Also does accum[i] <- exp(accum[i] - mi) + iterative_softmax( + accum_o, + accum, + mi, + m_prime, + s_prime, + out_rescale, + shared_storage.addition_storage, + lane_id(), + thread_id(), + warp_id(), + num_keys - iter_key_start, + iter_key_start == 0, + iteratorC_tile_offset, + kSupportsBias ? 1.0f : params.scale); + + // Output results to shared-memory + int warp_idx_mn_0 = warp_id() % + (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN); + auto output_tile_coords = cutlass::MatrixCoord{ + warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM, + warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM}; + + MM0::B2bGemm::accumToSmem( + shared_storage.after_mm0.si, accum, lane_id(), output_tile_coords); + + __syncthreads(); + + // + // MATMUL: Attn . V + // Run the matmul `attn @ V` for a block of attn and V. + // `attn` is read from shared memory (in `shared_storage_si`) + // `V` is read from global memory (with iterator_B) + // + + const int64_t nBlockN = kKeepOutputInRF ? 
1 + : ceil_div( + (int64_t)problem_size_1_n, + int64_t(MM1::ThreadblockShape::kN)); + + // Iterate over the N dimension of GEMM1 + for (int blockN = 0; blockN < nBlockN; ++blockN) { + int gemm_k_iterations = + (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add and store it in accum + // (in registers) + if (!kPreloadV) { + __syncthreads(); // we share shmem between mma and epilogue + } + + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{typename MM1::LayoutB(params.ldv[problem_idx])}, + params.ptr_V[problem_idx] + iter_key_start * params.ldv[problem_idx], + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + + typename MM1::Mma mma_pv( + // operand A: Pij_dropped in shared memory + shared_storage.after_mm0.si.accum_ref(), + // operand B: shared memory staging area for Vj, which is loaded + // from global memory + shared_storage.after_mm0.mm1.operand_B_ref(), + (int)thread_id(), + (int)warp_id(), + (int)lane_id()); + + mma_pv.set_prologue_done(kPreloadV); + if (!kKeepOutputInRF) { + accum_o.clear(); + } + + mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o); + __syncthreads(); + + if (kPreloadV && !kKeepOutputInRF && blockN + 1 < nBlockN) { + prologueV(blockN + 1); + } + + if (!kKeepOutputInRF) { + MM1::Mma::drain_cp_asyncs(); + DISPATCH_BOOL( + iter_key_start == 0, kIsFirst, ([&] { + DISPATCH_BOOL( + (iter_key_start + kKeysPerBlock) >= num_keys, + kIsLast, + ([&] { + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = typename DefaultOp::ElementCompute; + using EpilogueOutputOp = typename cutlass::epilogue:: + thread::MemoryEfficientAttentionNormalize< + typename cutlass::platform::conditional< + kIsLast::value, + output_t, + output_accum_t>::type, + output_accum_t, + DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, + output_accum_t, + kIsFirst::value, + kIsLast::value, + cutlass::Array>; + using Epilogue = typename cutlass::epilogue::threadblock:: + EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename cutlass::platform::conditional< + kIsLast::value, + typename MM1::OutputTileIterator, + typename MM1::OutputTileIteratorAccum>::type, + typename DefaultEpilogue:: + AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1::OutputTileIteratorAccum // Read + // iterator + >; + + int col = blockN * MM1::Mma::Shape::kN; + auto source_iter = createOutputAccumIter(col); + auto dest_iter = gemm_kernel_utils::call_conditional< + kIsLast::value, + decltype(createOutputIter), + decltype(createOutputAccumIter)>:: + apply(createOutputIter, createOutputAccumIter, col); + EpilogueOutputOp rescale(s_prime, out_rescale); + Epilogue epilogue( + shared_storage.epilogue_shared_storage(), + thread_id(), + warp_id(), + lane_id()); + epilogue(rescale, dest_iter, accum_o, source_iter); + })); + })); + if (!kKeepOutputInRF) { + __syncthreads(); + } + } + } + __syncthreads(); // we modify `m_prime` after + } + + if (kKeepOutputInRF) { + constexpr bool kIsFirst = true; + constexpr bool kIsLast = true; + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using 
DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = typename DefaultOp::ElementCompute; + using EpilogueOutputOp = + typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize< + output_t, // output + output_accum_t, // source + DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, // accum + output_accum_t, // compute + kIsFirst, + kIsLast, + cutlass::Array>; + using Epilogue = + typename cutlass::epilogue::threadblock::EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename MM1::OutputTileIterator, // destination + typename DefaultEpilogue::AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1::OutputTileIteratorAccum // source tile + >; + auto dest_iter = createOutputIter(0); + EpilogueOutputOp rescale(s_prime, out_rescale); + Epilogue epilogue( + shared_storage.epilogue_shared_storage(), + thread_id(), + warp_id(), + lane_id()); + MM1::Mma::drain_cp_asyncs(); + epilogue(rescale, dest_iter, accum_o); + } + + // Next tile + problem_visitor.advance(gridDim.x); + __syncthreads(); // Don't start the next iteration until all threads are done using shared memory. + } + } + + template + CUTLASS_DEVICE static void iterative_softmax( + typename WarpIteratorC::Fragment& frag_o, // output so far + typename WarpIteratorC::Fragment& frag, + cutlass::Array& mi, + cutlass::Array& m_prime, + cutlass::Array& s_prime, + cutlass::Array& out_rescale, + cutlass::Array& + addition_storage, + int8_t lane_id, + int8_t thread_id, + int8_t warp_id, + int max_col, + bool is_first, + typename WarpIteratorC::TensorCoord const& tile_offset, + float scaling) { + /* Iterates on the accumulator and corresponding position on result matrix + + (1) Update `mi[r]` to the max value of the row `r` + (2) In a second iteration do the following: + (a) accum <- exp(accum - mi) + (b) m_prime <- exp(m_prime - mi) + (c) s_prime <- s_prime * m_prime + sum(accum) + + All of this is done on registers, before we store all of this + on shared memory for the next matmul with Value. + */ + using Fragment = typename WarpIteratorC::Fragment; + using LambdaIterator = typename DefaultMmaAccumLambdaIterator< + WarpIteratorC, + accum_t, + kThreadsPerWarp>::Iterator; + // Convert to `accum_t` (rather than double) + constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E + + static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, ""); + static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock; + + frag = cutlass::multiplies()(scaling * kLog2e, frag); + + auto lane_offset = + LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset); + + // First update `mi` to the max per-row + { + accum_t max; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + max = -cutlass::platform::numeric_limits::infinity(); + }, + [&](int accum_m, int accum_n, int idx) { + if (accum_n < max_col) { + max = cutlass::fast_max(max, frag[idx]); + } + }, + [&](int accum_m) { + // Having 4x atomicMax seems faster than reduce within warp + // first... + atomicMaxFloat(&mi[accum_m], max); + }); + } + + // Make sure we all share the update values for `mi` + __syncthreads(); + + // Doing this `exp` is quite expensive. 
Let's + // split it across the warps + bool restore_mi_to_minus_inf = false; + if (lane_id < kLinesPerWarp) { + int id = warp_id * kLinesPerWarp + lane_id; + auto m_prime_id = m_prime[id]; + auto mi_id = mi[id]; + bool changed = m_prime_id < mi_id; // `false` if both are -inf + if (changed) { + auto m_prime_exp = exp2f(m_prime_id - mi_id); + out_rescale[id] = m_prime_exp; + s_prime[id] *= m_prime_exp; + } else { + // Only when bias is enabled, it's possible that all the first values + // of attention are masked to `-inf`. In that case we want to avoid + // `nan = exp2f(-inf - (-inf))` so we temporarily set `mi` to 0 + if (kSupportsBias && + mi_id == -cutlass::platform::numeric_limits::infinity()) { + restore_mi_to_minus_inf = true; + mi[id] = 0.0f; + } + out_rescale[id] = 1.0f; + } + } + __syncthreads(); // Update output fragments + if (kKeepOutputInRF && !is_first) { + accum_t line_rescale; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { line_rescale = out_rescale[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + frag_o[idx] = frag_o[idx] * line_rescale; + }, + [&](int accum_m) {}); + } + // Update accum_m, accum_n, ... + { + accum_t mi_row, total_row; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { mi_row = mi[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + frag[idx] = + (accum_n < max_col) ? exp2f(frag[idx] - mi_row) : accum_t(0.0); + }, + [&](int accum_m) {}); + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { total_row = 0.0; }, + [&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; }, + [&](int accum_m) { + if (LambdaIterator::reduceSameRow( + lane_id, total_row, [](accum_t a, accum_t b) { + return a + b; + })) { + // NOTE: we could atomically add `total_row` to `s_prime`, but + // it's faster (and deterministic) to avoid atomics here + addition_storage + [accum_m + kQueriesPerBlock * tile_offset.column()] = + total_row; + } + }); + } + + __syncthreads(); + if (lane_id < kLinesPerWarp) { + int id = warp_id * kLinesPerWarp + lane_id; + accum_t total_row = s_prime[id]; + if (restore_mi_to_minus_inf) { + // Restore `mi`, see above when we set `restore_mi_to_minus_inf=true` + mi[id] = -cutlass::platform::numeric_limits::infinity(); + } else { + m_prime[id] = mi[id]; + } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) { + total_row += addition_storage[id + kQueriesPerBlock * i]; + } + s_prime[id] = total_row; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/41_fused_multi_head_attention/fmha_grouped_problem_visitor.h b/examples/41_fused_multi_head_attention/fmha_grouped_problem_visitor.h new file mode 100644 index 0000000000..38695d5a81 --- /dev/null +++ b/examples/41_fused_multi_head_attention/fmha_grouped_problem_visitor.h @@ -0,0 +1,178 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Scheduler for grouped FMHA +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/gemm/kernel/grouped_problem_visitor.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { +// Helper for correctly representing problem sizes in grouped kernels +template +struct FMHAGroupedProblemSizeHelper { + + CUTLASS_HOST_DEVICE + static cutlass::gemm::GemmCoord grid_shape(const cutlass::gemm::GemmCoord& problem) { + // FMHA only partitions tiles across the M dimension. 
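+    // i.e. a 1-D grid of ceil_div(problem.m(), ThreadblockShape::kM) tiles along M;
+    // the N and K grid dimensions are always 1.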
+ return cutlass::gemm::GemmCoord( + ((problem.m() - 1 + ThreadblockShape::kM) / ThreadblockShape::kM), 1, 1); + } + + CUTLASS_HOST_DEVICE + static void possibly_transpose_problem(cutlass::gemm::GemmCoord& problem) {} + + CUTLASS_HOST_DEVICE + static int32_t tile_count(const cutlass::gemm::GemmCoord& grid) { + return grid.m() * grid.n(); + } +}; + +} // namespace detail + +/// Visitor class to abstract away the algorithm for iterating over tiles +template +struct FMHAGroupedProblemVisitor : public GroupedProblemVisitor< + detail::FMHAGroupedProblemSizeHelper, + ThreadblockShape, + GroupScheduleMode_, + PrefetchTileCount, + ThreadCount> { + + using ProblemSizeHelper = detail::FMHAGroupedProblemSizeHelper; + using Base = GroupedProblemVisitor; + using BaseParams = typename Base::Params; + using SharedStorage = typename Base::SharedStorage; + + cutlass::gemm::GemmCoord const *problem_sizes0; + cutlass::gemm::GemmCoord const *problem_sizes1; + + struct Params { + cutlass::gemm::GemmCoord const *problem_sizes0; + cutlass::gemm::GemmCoord const *problem_sizes1; + int32_t problem_count; + void const *workspace; + int32_t tile_count; + + // + // Methods + // + + /// Ctor + CUTLASS_HOST_DEVICE + Params(): problem_sizes0(nullptr), problem_sizes1(nullptr), + problem_count(0), workspace(nullptr), tile_count(0) { } + + /// Ctor + CUTLASS_HOST_DEVICE + Params( + cutlass::gemm::GemmCoord const *problem_sizes0, + cutlass::gemm::GemmCoord const *problem_sizes1, + int32_t problem_count, + void const *workspace = nullptr, + int32_t tile_count = 0 + ): + problem_sizes0(problem_sizes0), + problem_sizes1(problem_sizes1), + problem_count(problem_count), + workspace(workspace), + tile_count(tile_count) + {} + + /// Convert the FMHA-specific parameters to those used by the base class + CUTLASS_HOST_DEVICE + BaseParams to_base() const { + return BaseParams(// Set problem_sizes as problem_sizes1 because these determine + // shape of the final output of FMHA + problem_sizes1, + problem_count, + workspace, + tile_count); + } + + }; + + // + // Methods + // + CUTLASS_DEVICE + FMHAGroupedProblemVisitor( + Params const ¶ms_, + SharedStorage &shared_storage_, + int32_t block_idx + ): Base ( + params_.to_base(), + shared_storage_, block_idx), + problem_sizes0(params_.problem_sizes0), + problem_sizes1(params_.problem_sizes1) + {} + + /// Returns the problem size 0 for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size0() const { + GemmCoord problem = problem_sizes0[this->problem_idx]; + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } + + /// Returns the problem size 1 for the current problem + CUTLASS_HOST_DEVICE + cutlass::gemm::GemmCoord problem_size1() const { + GemmCoord problem = problem_sizes1[this->problem_idx]; + ProblemSizeHelper::possibly_transpose_problem(problem); + return problem; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu b/examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu new file mode 100644 index 0000000000..544e400fc9 --- /dev/null +++ b/examples/41_fused_multi_head_attention/fused_multi_head_attention_backward.cu @@ -0,0 +1,298 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include + +#include "kernel_backward.h" + +#include "cutlass/util/device_memory.h" +#include "cutlass/util/host_tensor.h" + + +using Arch = cutlass::arch::Sm80; +static constexpr int kMaxK = 128; + +template +struct DefaultKernel { + // Some heuristics to select the best kernel (tested on Sm60, Sm70, Sm80) + // NOTE: Requires quite a lot of shmem for Sm80+, + // so might require tweaking those manually for Sm86/Sm89 + + static constexpr bool kSupports64x128 = + ArchTag::kMinComputeCapability >= 80 || + (ArchTag::kMinComputeCapability >= 70 && + cutlass::sizeof_bits::value <= 16); + static constexpr int kBlockSizeI = kSupports64x128 && kMaxK > 64 ? 128 : 64; + static constexpr bool kIsHalf = cutlass::sizeof_bits::value <= 16; + static constexpr bool kOutputInRF = kIsHalf && kMaxK <= kBlockSizeI; + static constexpr bool kPreload = kIsHalf && ArchTag::kMinComputeCapability >= 80 && kOutputInRF; + static constexpr int kBlockSizeJ = kPreload && kMaxK > 64 ? 
128 : 64; + + using Kernel = AttentionBackwardKernel< + Arch, + Element, + true, // kIsAligned_ + false, // kApplyDropout_ + kPreload, // kPreload_ + kBlockSizeI, // kBlockSizeI_, + kBlockSizeJ, // kBlockSizeJ_, + kMaxK, // kMaxK + false, // kKeysQueriesAlignedToBlockSize + true // kEnableSplitKeys + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace { +template struct TypeName; +template <> struct TypeName { static constexpr const char* Name = "f32"; }; +template <> struct TypeName { static constexpr const char* Name = "f16"; }; +template <> struct TypeName { static constexpr const char* Name = "b16"; }; + +void readExpect(std::string const& expected) { + std::string read; + std::cin >> read; + if (read != expected) { + std::cerr << "FATAL: Read '" << read << "' but expected '" << expected << "'" << std::endl; + std::exit(1); + } +} + +/// Helpers to read from stdin +template +cutlass::HostTensor readTensorOnDevice(std::string const& expectedName) { + readExpect("tensor_begin"); + readExpect(std::string(TypeName::Name) + ":" + expectedName); + uint64_t len = 0; + std::cin >> len; + readExpect("file"); + std::string filename; + std::cin >> filename; + + cutlass::HostTensor tensor({int64_t(1), int64_t(len / sizeof(Element))}); + uint8_t* data = (uint8_t*)tensor.host_data(); + + std::fstream myFile(filename, std::ios::in | std::ios::binary ); + myFile.read((char*)data, len); + readExpect("tensor_end"); + tensor.sync_device(); + return tensor; +} + +int64_t readInt64(std::string const& expectedName) { + readExpect(expectedName); + int64_t s = 0; + std::cin >> s; + return s; +} + +float readFloat(std::string const& expectedName) { + readExpect(expectedName); + float s = 0; + std::cin >> s; + return s; +} + +// Writing +template +void writeTensor(std::string const& name, cutlass::HostTensor& tensor) { + tensor.sync_host(); // device->host + size_t u8len = tensor.size() * sizeof(Element); + + // Python is expected to provide a file name to write to + readExpect("tmpfile"); + std::string tmpfile; + std::cin >> tmpfile; + + uint8_t* data = (uint8_t*)tensor.host_data(); + std::fstream myFile(tmpfile, std::ios::out | std::ios::binary ); + myFile.write((char*)data, u8len); + myFile.close(); + + std::cout << "tensor_begin " << TypeName::Name << ":" << name << " "; + std::cout << u8len << " file " << tmpfile << " tensor_end" << std::endl; +} + +void writeInt64(std::string const& name, int64_t value) { + std::cout << name << " " << value << std::endl; +} +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +int runKernel() { + using Kernel = typename DefaultKernel::Kernel; + +#define READ_I64(NAME) p.NAME = (decltype(p.NAME))readInt64(#NAME) +#define READ_TENSOR_AND_STRIDES_BMH(DT, NAME, NAME_XS) \ + auto storage##NAME = readTensorOnDevice
(#NAME); \ + p.NAME##_ptr = storage##NAME.device_data(); \ + READ_I64(NAME_XS##_strideB); \ + READ_I64(NAME_XS##_strideM); \ + READ_I64(NAME_XS##_strideH); + +#define CUDA_CHECK(FN) { \ + auto cudaError = FN; \ + if (cudaError != cudaSuccess) { \ + std::cerr << "FATAL: " #FN " failed: " << cudaGetErrorString(cudaError) << std::endl; \ + return -1; \ + } \ +} + + typename Kernel::Params p; + p.scale = readFloat("scale"); + READ_I64(head_dim); + READ_I64(head_dim_value); + READ_I64(num_queries); + READ_I64(num_keys); + READ_I64(num_heads); + READ_I64(custom_mask_type); + READ_I64(num_batches); + int64_t repeat_count = readInt64("repeat_count"); + READ_I64(num_splits_key); + + READ_TENSOR_AND_STRIDES_BMH(Element, query, q); + READ_TENSOR_AND_STRIDES_BMH(Element, key, k); + READ_TENSOR_AND_STRIDES_BMH(Element, value, v); + auto lse = readTensorOnDevice("logsumexp"); + p.logsumexp_ptr = lse.device_data(); + p.lse_strideB = readInt64("lse_strideB"); + p.lse_strideH = readInt64("lse_strideH"); + + // output + auto stOutput = readTensorOnDevice("output"); + p.output_ptr = stOutput.device_data(); + READ_I64(o_strideB); + auto o_strideM = readInt64("o_strideM"); + if (o_strideM != p.o_strideM()) { + std::cerr << "Invalid `o_strideM`: " << o_strideM << " - expected " << p.o_strideM(); + return 2; + } + READ_I64(o_strideH); + + READ_TENSOR_AND_STRIDES_BMH(Element, grad_output, gO); + + auto stDelta = readTensorOnDevice("delta"); + p.delta_ptr = stDelta.device_data(); + READ_I64(delta_strideB); + READ_I64(delta_strideH); + + // Allocate workspace + if (p.workspace_size()) { + cudaMalloc(&p.workspace, p.workspace_size()); + } + + // Allocate outputs in BMHK format + p.gQKV_strideM_multiplier = 1; + p.gQ_strideH = p.head_dim; + p.gQ_strideB = p.gQ_strideM() * p.num_queries; + p.gK_strideH = p.head_dim; + p.gK_strideB = p.gK_strideM() * p.num_keys; + p.gV_strideH = p.head_dim_value; + p.gV_strideB = p.gV_strideM() * p.num_keys; + + cutlass::HostTensor gQ({int64_t(1), p.gQ_strideB * p.num_batches}); + cutlass::HostTensor gK({int64_t(1), p.gK_strideB * p.num_batches}); + cutlass::HostTensor gV({int64_t(1), p.gV_strideB * p.num_batches}); + p.grad_query_ptr = gQ.device_data(); + p.grad_key_ptr = gK.device_data(); + p.grad_value_ptr = gV.device_data(); + + if (!Kernel::check_supported(p)) { + std::cerr << "FATAL: Kernel does not support these inputs" << std::endl; + return 2; + } + + // Run kernel + cudaDeviceSynchronize(); + auto kernel_fn = attention_kernel_backward_batched_impl; + size_t smem_bytes = sizeof(typename Kernel::SharedStorage); + CUDA_CHECK(cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, int(smem_bytes))); + kernel_fn<<>>(p); + + // Write outputs + std::cout << "OK "; + writeTensor("grad_query", gQ); + writeInt64("gQ_strideB", p.gQ_strideB); + writeInt64("gQ_strideM", p.gQ_strideM()); + writeInt64("gQ_strideH", p.gQ_strideH); + writeTensor("grad_key", gK); + writeInt64("gK_strideB", p.gK_strideB); + writeInt64("gK_strideM", p.gK_strideM()); + writeInt64("gK_strideH", p.gK_strideH); + writeTensor("grad_value", gV); + writeInt64("gV_strideB", p.gV_strideB); + writeInt64("gV_strideM", p.gV_strideM()); + writeInt64("gV_strideH", p.gV_strideH); + + // Timing + cudaEvent_t events[2]; + for (auto & event : events) { + CUDA_CHECK(cudaEventCreate(&event)); + } + CUDA_CHECK(cudaEventRecord(events[0])); + for (int i = 0; i < repeat_count; ++i) { + kernel_fn<<>>(p); + } + CUDA_CHECK(cudaEventRecord(events[1])); + CUDA_CHECK(cudaEventSynchronize(events[1])); + // Measure elapsed 
runtime + float runtime_ms = 0; + CUDA_CHECK(cudaEventElapsedTime(&runtime_ms, events[0], events[1])); + + std::cout << "runtime_ms " << runtime_ms / float(repeat_count) << std::endl; + return 0; +} + +int main() { + std::ios_base::sync_with_stdio(false); + + std::string dtype; + std::cin >> dtype; + std::cerr << "Running kernel with dtype: " << dtype << std::endl; + if (dtype == "f16") { + return runKernel(); + } else if (dtype == "b16") { + return runKernel(); + } else if (dtype == "f32") { + return runKernel(); + } else { + std::cerr << "FATAL: Unknown dtype: " << dtype << std::endl; + return 3; + } +} +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu b/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu new file mode 100644 index 0000000000..cf02a7b933 --- /dev/null +++ b/examples/41_fused_multi_head_attention/fused_multihead_attention_fixed_seqlen.cu @@ -0,0 +1,1110 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief CUTLASS Attention Example. + + This workload computes a fused multi head attention. + Because it keeps the attention matrix in shared memory, it's both faster and + uses less global memory. + + This is based on `"Self-Attention Does Not Need O(n^2) Memory" `_, + and very similar to `"FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" `_. 
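+
+    The algorithm is summarized in the section below. For readers who prefer complete code, the
+    following is a minimal host-side C++ sketch of the same streaming-softmax formulation,
+    including the running-maximum rescaling explained below (a sketch only, assuming a single
+    head, row-major float buffers, and an illustrative helper name; the fused kernel performs
+    the equivalent computation block-wise in shared memory and registers):
+
+    ```
+    #include <algorithm>
+    #include <cmath>
+    #include <vector>
+
+    // Reference streaming attention for one head: O = softmax(scale * Q K^T) V.
+    void attention_reference_streaming(
+        std::vector<float> const& Q,   // [num_queries x head_dim]
+        std::vector<float> const& K,   // [num_keys    x head_dim]
+        std::vector<float> const& V,   // [num_keys    x head_dim_v]
+        std::vector<float>&       O,   // [num_queries x head_dim_v]
+        int num_queries, int num_keys, int head_dim, int head_dim_v,
+        float scale, int key_block_size = 64) {
+
+      O.assign(size_t(num_queries) * head_dim_v, 0.f);
+      for (int q = 0; q < num_queries; ++q) {
+        float m_prime = -INFINITY;   // running maximum of the scaled logits
+        float s_prime = 0.f;         // running sum of exp(logit - running maximum)
+        for (int kv0 = 0; kv0 < num_keys; kv0 += key_block_size) {
+          int kv1 = std::min(kv0 + key_block_size, num_keys);
+          // New maximum over this block of keys, merged with the running maximum.
+          float mi = m_prime;
+          for (int kv = kv0; kv < kv1; ++kv) {
+            float logit = 0.f;
+            for (int d = 0; d < head_dim; ++d) logit += Q[q * head_dim + d] * K[kv * head_dim + d];
+            mi = std::max(mi, logit * scale);
+          }
+          // Rescale previously accumulated results when the maximum grows.
+          float correction = std::exp(m_prime - mi);
+          s_prime *= correction;
+          for (int d = 0; d < head_dim_v; ++d) O[q * head_dim_v + d] *= correction;
+          m_prime = mi;
+          // Accumulate this block's contribution.
+          for (int kv = kv0; kv < kv1; ++kv) {
+            float logit = 0.f;
+            for (int d = 0; d < head_dim; ++d) logit += Q[q * head_dim + d] * K[kv * head_dim + d];
+            float si = std::exp(logit * scale - mi);
+            s_prime += si;
+            for (int d = 0; d < head_dim_v; ++d) O[q * head_dim_v + d] += si * V[kv * head_dim_v + d];
+          }
+        }
+        // Final normalization by the softmax denominator.
+        for (int d = 0; d < head_dim_v; ++d) O[q * head_dim_v + d] /= s_prime;
+      }
+    }
+    ```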
+ + Algorithm: + In short, we can compute the output incrementally in blocks of size B, + we just need to divide the final result by the sum of all coefficients in + the softmax (which we compute incrementally) with the following pseudo-code: + + ``` + s_prime = torch.zeros([num_queries, B]) + O = torch.zeros([num_queries, head_size_v]) + for i in range(0, K.shape[0], B): + si = exp((Q . K[i * B:(i+1) * B].t) * scale) + sum_coefs += attn_unscaled.sum(-1) + O += si . V[i * B:(i+1) * B] + O = O / s_prime + ``` + + In practice, and for numerical stability reasons, + we also substract the maximum so far (`mi`) before doing + the exponential. When we encounter new keys, the maximum + used to compute O so far (`m_prime`) can differ from the + current maximum, so we update O before accumulating with + + ``` + O = O * exp(m_prime - mi) + m_prime = mi + ``` + + Implementation details: + - `si` is stored in shared memory between the 2 back to back gemms + - we keep and accumulate the output + directly in registers if we can (`head_size_v <= 128`). + Otherwise, we store it & accumulate in global memory (slower) + - blocks are parallelized across the batch dimension, the number + of heads, and the query sequence size + + + Examples: + + # Run an attention example with default setup + $ ./examples/41_fused_multi_head_attention/41_fused_multi_head_attention_fixed_seqlen + + # Run an attention example with custom setup + $ ./examples/41_fused_multi_head_attention/41_fused_multi_head_attention_fixed_seqlen --head_number=2 --batch_size=3 --head_size=32 --head_size_v=64 --seq_length=512 --seq_length_kv=1024 --causal=true + + Acknowledgement: Fixed-sequence-length FMHA code was upstreamed by Meta xFormers (https://github.com/facebookresearch/xformers). +*/ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm_complex.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/gemm/kernel/gemm_grouped.h" +#include "cutlass/gemm/kernel/gemm_transpose_operands.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/kernel/default_gemm_complex.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/epilogue/threadblock/epilogue_with_visitor.h" +#include "cutlass/fast_math.h" +#include "kernel_forward.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Result structure +struct Result { + + double runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + // + // Methods + // + + Result( + double runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = 
cudaSuccess + ): + runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + bool error; + bool reference_check; + bool use_mask; + bool causal; + + std::vector problem_sizes0; + std::vector problem_sizes1; + + std::vector problem_sizes0_real; + std::vector problem_sizes1_real; + + int alignment; + int head_number; + int batch_size; + int head_size; + int head_size_v; + int seq_length; + int seq_length_kv; + int iterations; + + // alpha0, alpha1 and beta are fixed + // in this multi-head attention example + float alpha0; + float alpha1; + float beta; + + // + // Methods + // + + Options(): + help(false), + error(false), + alignment(1), + reference_check(true), + head_number(12), + batch_size(16), + head_size(64), + head_size_v(64), + seq_length(1024), + seq_length_kv(1024), + use_mask(false), + iterations(20), + causal(false) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("alignment", alignment, 1); + cmd.get_cmd_line_argument("head_number", head_number, 12); + cmd.get_cmd_line_argument("batch_size", batch_size, 16); + cmd.get_cmd_line_argument("head_size", head_size, 64); + cmd.get_cmd_line_argument("head_size_v", head_size_v, head_size); + cmd.get_cmd_line_argument("seq_length", seq_length, 1024); + cmd.get_cmd_line_argument("seq_length_kv", seq_length_kv, seq_length); + cmd.get_cmd_line_argument("use_mask", use_mask, false); + cmd.get_cmd_line_argument("iterations", iterations, 20); + cmd.get_cmd_line_argument("reference-check", reference_check, true); + cmd.get_cmd_line_argument("causal", causal, true); + + randomize_problems(); + + } + + void randomize_problems() { + + int problem_count = head_number * batch_size; + + problem_sizes0.reserve(problem_count); + problem_sizes1.reserve(problem_count); + + // When using mask, the original inputs are not padded + // and we need to save these info. + if (use_mask) { + problem_sizes0_real.reserve(problem_count); + problem_sizes1_real.reserve(problem_count); + } + + for (int i = 0; i < batch_size; ++i) { + // problems belonging to the same batch share the same seq len + int m_real = seq_length; + int mkv_real = seq_length_kv; + int m = (m_real + alignment - 1) / alignment * alignment; + int mkv = (mkv_real + alignment - 1) / alignment * alignment; + int k0 = head_size; + int k1 = head_size_v; + + for (int j = 0; j < head_number; ++j) { + cutlass::gemm::GemmCoord problem0(m, mkv, k0); + cutlass::gemm::GemmCoord problem1(m, k1, mkv); + problem_sizes0.push_back(problem0); + problem_sizes1.push_back(problem1); + + if (use_mask) { + cutlass::gemm::GemmCoord problem0_real(m_real, mkv_real, k0); + cutlass::gemm::GemmCoord problem1_real(m_real, k1, mkv_real); + problem_sizes0_real.push_back(problem0_real); + problem_sizes1_real.push_back(problem1_real); + } + } + } + } + + /// Prints the usage statement. 
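+  /// Invoked from main() when the --help flag is passed.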
+ std::ostream & print_usage(std::ostream &out) const { + + out << "41_fused_multi_head_attention_fixed_seqlen\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --head_number= Head number in multi-head attention (default: --head_number=12)\n" + << " --batch_size= Batch size in multi-head attention (default: --batch_size=16)\n" + << " --head_size= Head size in multi-head attention (default: --head_size=64)\n" + << " --head_size_v= Head size in multi-head attention for V (default: --head_size_v=head_size)\n" + << " --seq_length= Sequence length in multi-head attention for Q (default: --seq_length=1024)\n" + << " --seq_length_kv= Sequence length in multi-head attention for K/V (default: --seq_length_kv=seq_length)\n" + << " --use_mask= If true, performs padding-like masking in softmax.\n" + << " --iterations= Number of profiling iterations to perform.\n" + << " --reference-check= If true, performs reference check.\n" + << " --causal= If true, uses causal masking.\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + + // Number of real-valued multiply-adds + int64_t fops = int64_t(); + + for (size_t i = 0; i < problem_sizes0.size(); ++i) { + auto const& problem0 = problem_sizes0[i]; + auto const& problem1 = problem_sizes1[i]; + for (int row = 0; row < problem0.m(); ++row) { + int num_cols0 = problem0.n(); + if (causal) { + num_cols0 = std::min(row + 1, num_cols0); + } + // P <- Q . K_t + fops += 2 * num_cols0 * problem0.k(); + // P <- exp(P - max(P)) + fops += 2 * num_cols0; + // S <- sum(P) + fops += num_cols0 - 1; + // O <- P . V + fops += 2 * num_cols0 * problem1.n(); + // O <- O / S + fops += num_cols0 * problem1.n(); + } + } + + return double(fops) / double(1.0e9) / runtime_s; + } +}; + + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class TestbedAttention { +public: + + // + // Type definitions + // + + using ElementQ = typename Attention::scalar_t; + using ElementK = typename Attention::scalar_t; + using ElementP = typename Attention::accum_t; + using ElementAccumulator = typename Attention::accum_t; + using ElementV = typename Attention::scalar_t; + using ElementO = typename Attention::output_t; + + using ElementCompute = typename Attention::accum_t; + + using ElementNorm = typename Attention::accum_t; + using ElementSum = typename Attention::accum_t; + using ElementSoftmaxCompute = typename Attention::accum_t; + + using LayoutQ = cutlass::layout::RowMajor; + using LayoutK = cutlass::layout::ColumnMajor; + using LayoutP = cutlass::layout::RowMajor; + using LayoutV = cutlass::layout::RowMajor; + using LayoutO = cutlass::layout::RowMajor; + + using MatrixCoord = typename LayoutP::TensorCoord; + +private: + + // + // Data members + // + + Options & options; + + /// Initialization + cutlass::Distribution::Kind init_Q; + cutlass::Distribution::Kind init_K; + cutlass::Distribution::Kind init_P; + cutlass::Distribution::Kind init_V; + cutlass::Distribution::Kind init_O; + uint32_t seed; + + cutlass::DeviceAllocation problem_sizes_device0; + cutlass::DeviceAllocation problem_sizes_device1; + cutlass::DeviceAllocation problem_sizes_device0_real; + + std::vector offset_Q; + std::vector offset_K; + std::vector offset_P; + std::vector offset_V; + std::vector offset_O; + + std::vector ldq_host; + std::vector ldk_host; + std::vector ldp_host; + std::vector ldv_host; + std::vector ldo_host; + std::vector seqlen_host; + + 
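+  // Device-side copies of the per-problem leading dimensions and sequence lengths above,
+  // filled from the corresponding *_host vectors in initialize_().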
cutlass::DeviceAllocation ldq; + cutlass::DeviceAllocation ldk; + cutlass::DeviceAllocation ldp; + cutlass::DeviceAllocation ldv; + cutlass::DeviceAllocation ldo; + cutlass::DeviceAllocation seqlen; + + cutlass::DeviceAllocation block_Q; + cutlass::DeviceAllocation block_K; + cutlass::DeviceAllocation block_P; + cutlass::DeviceAllocation block_V; + cutlass::DeviceAllocation block_O; + cutlass::DeviceAllocation block_Norm; + cutlass::DeviceAllocation block_Sum; + + cutlass::DeviceAllocation offset_P_Device; + + cutlass::DeviceAllocation ptr_Q; + cutlass::DeviceAllocation ptr_K; + cutlass::DeviceAllocation ptr_P; + cutlass::DeviceAllocation ptr_V; + cutlass::DeviceAllocation ptr_O; + +public: + + // + // Methods + // + + TestbedAttention( + Options &options_, + cutlass::Distribution::Kind init_Q_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_K_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_P_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_V_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_O_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3080 + ): + options(options_), init_Q(init_Q_), init_K(init_K_), init_P(init_P_), init_V(init_V_), init_O(init_O_), seed(seed_) { } + + int problem_count() const { + return (options.head_number * options.batch_size); + } + +private: + + /// Helper to initialize a tensor view + template + void initialize_tensor_( + Element *ptr, + size_t capacity, + cutlass::Distribution::Kind dist_kind, + uint32_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 8; + scope_min = -8; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::device::BlockFillRandomUniform( + ptr, capacity, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::device::BlockFillRandomGaussian( + ptr, capacity, seed, Element(), Element(0.5f)); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + // Fill with increasing elements + cutlass::reference::device::BlockFillSequential( + ptr, capacity, Element(1), Element()); + } + else { + + // Fill with all 1s + cutlass::reference::device::BlockFillSequential( + ptr, capacity, Element(), Element(1)); + } + } + + /// Initializes data structures + void initialize_() { + + // + // Set scalors for the mha example + // + + options.alpha0 = 1.0f / sqrt(float(options.head_size)); + options.alpha1 = 1.0f; + options.beta = 0; + + // + // Choose random problem sizes + // + + // construct a few problems of random sizes + srand(seed); + + int64_t total_elements_Q = 0; + int64_t total_elements_K = 0; + int64_t total_elements_P = 0; + int64_t total_elements_V = 0; + int64_t total_elements_O = 0; + + ldq_host.resize(problem_count()); + ldk_host.resize(problem_count()); + ldp_host.resize(problem_count()); + ldv_host.resize(problem_count()); + ldo_host.resize(problem_count()); + seqlen_host.resize(problem_count()); + + // Create tensors in BMHK format, where + // B = batch_size + // M = sequence length + // H = num_heads + // K = embedding size per head + int64_t batch_offset_Q, batch_offset_K, batch_offset_V, batch_offset_O; + + for (int32_t b = 0; b < 
options.batch_size; ++b) { + batch_offset_Q = total_elements_Q; + batch_offset_K = total_elements_K; + batch_offset_V = total_elements_V; + batch_offset_O = total_elements_O; + for (int32_t h = 0; h < options.head_number; ++h) { + int32_t i = h + b * options.head_number; + + auto problem0 = options.problem_sizes0.at(i); + auto problem1 = options.problem_sizes1.at(i); + + ldq_host.at(i) = LayoutQ::packed({problem0.m(), options.head_number * problem0.k()}).stride(0); + ldk_host.at(i) = LayoutK::packed({options.head_number * problem0.k(), problem0.n()}).stride(0); + ldp_host.at(i) = LayoutP::packed({problem0.m(), problem0.n()}).stride(0); + ldv_host.at(i) = LayoutV::packed({problem1.k(), options.head_number * problem1.n()}).stride(0); + ldo_host.at(i) = LayoutO::packed({problem1.m(), options.head_number * problem1.n()}).stride(0); + + // m = n for attention problems. + seqlen_host.at(i) = problem0.m(); + + offset_Q.push_back(batch_offset_Q + h * problem0.k()); + offset_K.push_back(batch_offset_K + h * problem0.k()); + offset_P.push_back(total_elements_P); + offset_V.push_back(batch_offset_V + h * problem0.k()); + offset_O.push_back(batch_offset_O + h * problem1.n()); + + int64_t elements_Q = problem0.m() * problem0.k(); + int64_t elements_K = problem0.k() * problem0.n(); + int64_t elements_P = problem0.m() * problem0.n(); + int64_t elements_V = problem1.k() * problem1.n(); + int64_t elements_O = problem1.m() * problem1.n(); + + total_elements_Q += elements_Q; + total_elements_K += elements_K; + total_elements_P += elements_P; + total_elements_V += elements_V; + total_elements_O += elements_O; + } + } + + problem_sizes_device0.reset(problem_count()); + problem_sizes_device1.reset(problem_count()); + problem_sizes_device0.copy_from_host(options.problem_sizes0.data()); + problem_sizes_device1.copy_from_host(options.problem_sizes1.data()); + + if (options.use_mask) { + problem_sizes_device0_real.reset(problem_count()); + problem_sizes_device0_real.copy_from_host(options.problem_sizes0_real.data()); + } + + ldq.reset(problem_count()); + ldk.reset(problem_count()); + ldp.reset(problem_count()); + ldv.reset(problem_count()); + ldo.reset(problem_count()); + seqlen.reset(problem_count()); + + ldq.copy_from_host(ldq_host.data()); + ldk.copy_from_host(ldk_host.data()); + ldp.copy_from_host(ldp_host.data()); + ldv.copy_from_host(ldv_host.data()); + ldo.copy_from_host(ldo_host.data()); + seqlen.copy_from_host(seqlen_host.data()); + + // + // Assign pointers + // + + block_Q.reset(total_elements_Q); + block_K.reset(total_elements_K); + block_P.reset(total_elements_P); + block_V.reset(total_elements_V); + block_O.reset(total_elements_O); + + offset_P_Device.reset(problem_count()); + + // sync offset with device + cutlass::device_memory::copy_to_device(offset_P_Device.get(), offset_P.data(), offset_P.size()); + + std::vector ptr_Q_host(problem_count()); + std::vector ptr_K_host(problem_count()); + std::vector ptr_P_host(problem_count()); + std::vector ptr_V_host(problem_count()); + std::vector ptr_O_host(problem_count()); + std::vector ptr_norm_host(problem_count()); + std::vector ptr_sum_host(problem_count()); + + for (int32_t i = 0; i < problem_count(); ++i) { + ptr_Q_host.at(i) = block_Q.get() + offset_Q.at(i); + ptr_K_host.at(i) = block_K.get() + offset_K.at(i); + ptr_P_host.at(i) = block_P.get() + offset_P.at(i); + ptr_V_host.at(i) = block_V.get() + offset_V.at(i); + ptr_O_host.at(i) = block_O.get() + offset_O.at(i); + } + + ptr_Q.reset(problem_count()); + ptr_Q.copy_from_host(ptr_Q_host.data()); + + 
ptr_K.reset(problem_count()); + ptr_K.copy_from_host(ptr_K_host.data()); + + ptr_P.reset(problem_count()); + ptr_P.copy_from_host(ptr_P_host.data()); + + ptr_V.reset(problem_count()); + ptr_V.copy_from_host(ptr_V_host.data()); + + ptr_O.reset(problem_count()); + ptr_O.copy_from_host(ptr_O_host.data()); + + // + // Initialize the problems of the workspace + // + + initialize_tensor_(block_Q.get(), total_elements_Q, init_Q, seed + 1); + initialize_tensor_(block_K.get(), total_elements_K, init_K, seed + 2); + initialize_tensor_(block_V.get(), total_elements_V, init_V, seed + 3); + + } + + template + bool verify_tensor_(std::vector vector_Input, \ + std::vector vector_Input_Ref, + int64_t verify_length = -1) { + + int64_t size = (vector_Input.size() < vector_Input_Ref.size()) ? vector_Input.size() : vector_Input_Ref.size(); + size = (verify_length == -1) ? size : verify_length; + + // 0.05 for absolute error + float abs_tol = 5e-2f; + // 10% for relative error + float rel_tol = 1e-1f; + for (int64_t i = 0; i < size; ++i) { + float diff = (float)(vector_Input.at(i) - vector_Input_Ref.at(i)); + float abs_diff = fabs(diff); + float abs_ref = fabs((float)vector_Input_Ref.at(i) + 1e-5f); + float relative_diff = abs_diff / abs_ref; + if ( (isnan(vector_Input_Ref.at(i)) || isnan(abs_diff) || isinf(abs_diff)) || (abs_diff > abs_tol && relative_diff > rel_tol)) { + printf("[%d/%d] diff = %f, rel_diff = %f, {computed=%f, ref=%f}.\n", int(i), int(size), abs_diff, relative_diff, (float)(vector_Input.at(i)), (float)(vector_Input_Ref.at(i))); + return false; + } + + } + + return true; + } + + /// Verifies the result is a GEMM + bool verify_() { + + bool passed = true; + + for (int32_t b = 0; b < options.batch_size; ++b) { + int32_t i = b * options.head_number; + // Problem size is the same for all heads + cutlass::gemm::GemmCoord problem0 = options.problem_sizes0.at(b * options.head_number); + cutlass::gemm::GemmCoord problem1 = options.problem_sizes1.at(b * options.head_number); + + MatrixCoord extent_Q{problem0.m(), problem0.k()}; + MatrixCoord extent_K{problem0.k(), problem0.n()}; + MatrixCoord extent_P{problem0.m(), problem0.n()}; + MatrixCoord extent_V{problem1.k(), problem1.n()}; + MatrixCoord extent_O{problem1.m(), problem1.n()}; + + LayoutO layout_O(ldo_host.at(i)); + std::vector matrix_O(layout_O.capacity(extent_O)); + cutlass::device_memory::copy_to_host(matrix_O.data(), block_O.get() + offset_O.at(i), matrix_O.size()); + cutlass::DeviceAllocation block_Ref_O(layout_O.capacity(extent_O)); + + for (int32_t h = 0; h < options.head_number; ++h) { + i = h + b * options.head_number; + + LayoutQ layout_Q(ldq_host.at(i)); + LayoutK layout_K(ldk_host.at(i)); + LayoutP layout_P(ldp_host.at(i)); + LayoutV layout_V(ldv_host.at(i)); + + cutlass::TensorView view_Q(block_Q.get() + offset_Q.at(i), layout_Q, extent_Q); + cutlass::TensorView view_K(block_K.get() + offset_K.at(i), layout_K, extent_K); + cutlass::TensorView view_V(block_V.get() + offset_V.at(i), layout_V, extent_V); + cutlass::TensorView view_Ref_O_device(block_Ref_O.get() + offset_O.at(i) - offset_O.at(b * options.head_number), layout_O, extent_O); + + cutlass::DeviceAllocation block_Ref_P(layout_P.capacity(extent_P)); + cutlass::TensorView view_Ref_P_device(block_Ref_P.get(), layout_P, extent_P); + + // Reference GEMM + cutlass::reference::device::GemmComplex< + ElementQ, LayoutQ, + ElementK, LayoutK, + ElementP, LayoutP, + ElementCompute, ElementAccumulator + >( + problem0, + ElementAccumulator(options.alpha0), + view_Q, + 
Attention::MM0::Mma::kTransformA, + view_K, + Attention::MM0::Mma::kTransformB, + ElementAccumulator(options.beta), + view_Ref_P_device, + view_Ref_P_device, + ElementAccumulator(0) + ); + + // Compute softmax for P. We need to explicitly compute softmax + // over P because softmax is fused to the second GEMM in the + // profiled implementation. + std::vector matrix_Ref(layout_P.capacity(extent_P)); + cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref_P.get(), matrix_Ref.size()); + cutlass::TensorView view_Ref_host(matrix_Ref.data(), layout_P, extent_P); + std::vector vector_Norm_Ref(problem0.m()); + std::vector vector_Sum_Ref(problem0.m()); + + int n_dim = options.use_mask ? options.problem_sizes0_real.at(i).n() : problem0.n(); + + // Compute softmax for reference matrix + for (int m = 0; m < problem0.m(); m++) { + int n_dim_row = n_dim; + if (options.causal) { + n_dim_row = std::min(m + 1, n_dim); + } + ElementSoftmaxCompute max = ElementSoftmaxCompute(view_Ref_host.ref().at({m, 0})); + for (int n = 1; n < n_dim_row; n++) { + max = std::max(max, ElementSoftmaxCompute(view_Ref_host.ref().at({m, n}))); + } + + vector_Norm_Ref.at(m) = ElementNorm(max); + + ElementSoftmaxCompute sum = ElementSoftmaxCompute(); + for (int n = 0; n < n_dim_row; n++) { + sum += std::exp( ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})) - max ); + } + ElementSoftmaxCompute inv_sum = ElementSoftmaxCompute(1.0f / sum); + + vector_Sum_Ref.at(m) = ElementSum(inv_sum); + + for (int n = 0; n < n_dim_row; n++) { + view_Ref_host.ref().at({m, n}) = ElementP( + std::exp( ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})) - max ) * inv_sum + ); + } + // Mask out the rest of the attention matrix + for (int n = n_dim_row; n < n_dim; ++n) { + view_Ref_host.ref().at({m, n}) = ElementP(0); + } + } + + // when not using mask, problem_real and problem share the same sizes + if (options.use_mask) { + for (int m = 0; m < problem0.m(); m++) { + for (int n = n_dim; n < problem0.n(); n++) { + view_Ref_host.ref().at({m, n}) = ElementP(0); + } + } + } + + cutlass::device_memory::copy_to_device(block_Ref_P.get(), matrix_Ref.data(), matrix_Ref.size()); + + // Reference GEMM + cutlass::reference::device::GemmComplex< + ElementP, LayoutP, + ElementV, LayoutV, + ElementO, LayoutO, + ElementCompute, ElementAccumulator + >( + problem1, + ElementAccumulator(options.alpha1), + view_Ref_P_device, + Attention::MM0::Mma::kTransformA, + view_V, + Attention::MM0::Mma::kTransformB, + ElementAccumulator(options.beta), + view_Ref_O_device, + view_Ref_O_device, + ElementAccumulator(0) + ); + } + + // Copy to host memory + std::vector matrix_Ref_O(layout_O.capacity(extent_O)); + cutlass::device_memory::copy_to_host(matrix_Ref_O.data(), block_Ref_O.get(), matrix_Ref_O.size()); + + // printf("Pb %d: \n Q=(offset=%d, ldq=%d)\n K=(offset=%d, ldk=%d)\n O=(offset=%d, ldo=%d)\n", + // int(i), int(offset_Q[i]), int(ldq_host[i]), int(offset_K[i]), int(ldk_host[i]), int(offset_O[i]), int(ldo_host[i])); + + bool verified_O = false; + + if (!verified_O) { + verified_O = verify_tensor_(matrix_O, matrix_Ref_O); + } + + passed = passed && verified_O; + + if (!passed) { + std::cerr << "\n***\nError - problem " << i << " (batch " << b << ") failed the QA check\n***\n" << std::endl; + + if (!verified_O) { + std::cout << "Final matrix output is incorrect" << std::endl; + } + + return passed; + } + } + + return passed; + } + +public: + + + /// Executes a CUTLASS Attention kernel and measures runtime. 
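+  /// The reported runtime is the average over options.iterations launches timed with CUDA
+  /// events; `passed` reflects the optional reference check.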
+ Result profile() { + + Result result; + result.passed = false; + + // Initialize the problem + initialize_(); + + typename Attention::Params p; + { // set parameters + p.query_ptr = block_Q.get(); + p.key_ptr = block_K.get(); + p.value_ptr = block_V.get(); + p.logsumexp_ptr = nullptr; // Only needed for bw + p.output_accum_ptr = nullptr; + if (Attention::kNeedsOutputAccumulatorBuffer) { + cudaMalloc(&p.output_accum_ptr, block_O.size() * sizeof(typename Attention::output_accum_t)); + } + p.output_ptr = block_O.get(); + + // TODO: support arbitrary seq lengths + // if (cu_seqlens_q.has_value()) { + // p.cu_seqlens_q_ptr = (int32_t*)cu_seqlens_q->data_ptr(); + // p.cu_seqlens_k_ptr = (int32_t*)cu_seqlens_k->data_ptr(); + // } + + p.scale = options.alpha0; + + p.num_heads = options.head_number; + p.num_batches = options.batch_size; + p.head_dim = options.head_size; + p.head_dim_value = options.head_size_v; + p.num_queries = options.seq_length; + p.num_keys = options.seq_length_kv; + if (options.causal) { + p.custom_mask_type = Attention::CausalFromTopLeft; + } + + // All tensors are in BMHK shapes + p.q_strideH = options.head_size; + p.k_strideH = options.head_size; + p.v_strideH = options.head_size_v; + p.q_strideM = int32_t(ldq_host[0]); + p.k_strideM = int32_t(ldk_host[0]); + p.v_strideM = int32_t(ldv_host[0]); + p.q_strideB = p.q_strideM * options.seq_length; + p.k_strideB = p.k_strideM * options.seq_length_kv; + p.v_strideB = p.v_strideM * options.seq_length_kv; + p.o_strideM = p.head_dim_value * p.num_heads; + } + + // launch kernel :) + constexpr auto kernel_fn = attention_kernel_batched_impl; + int smem_bytes = sizeof(typename Attention::SharedStorage); + if (smem_bytes > 0xc000) { + cudaFuncSetAttribute(kernel_fn, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_bytes); + } + if (!Attention::check_supported(p)) { + std::cerr << "Kernel does not support these inputs" << std::endl; + return result; + } + kernel_fn<<>>(p); + + // Wait for completion + result.error = cudaDeviceSynchronize(); + + if (result.error != cudaSuccess) { + std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); + return result; + } + + // + // Verify correctness + // + result.passed = true; + + if (options.reference_check) { + result.passed = verify_(); + } + + // + // Warm-up run + // + + kernel_fn<<>>(p); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run CUTLASS Attention kernel." << std::endl; + return result; + } + + // + // Construct events + // + + cudaEvent_t events[2]; + + for (auto & event : events) { + result.error = cudaEventCreate(&event); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; + return -1; + } + } + + // Record an event at the start of a series of GEMM operations + result.error = cudaEventRecord(events[0]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // + // Run profiling loop + // + + for (int iter = 0; iter < options.iterations; ++iter) { + kernel_fn<<>>(p); + } + + // + // Stop profiling loop + // + + // Record an event when the GEMM operations have been launched. + result.error = cudaEventRecord(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Wait for work on the device to complete. 
+ result.error = cudaEventSynchronize(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Compute average runtime and GFLOPs. + result.runtime_ms = double(runtime_ms) / double(options.iterations); + result.gflops = options.gflops(result.runtime_ms / 1000.0); + + // + // Cleanup + // + + for (auto event : events) { + (void)cudaEventDestroy(event); + } + + std::cout << std::endl; + std::cout << "CUTLASS Attention:\n" + << "====================================================" << std::endl; + std::cout << " " << " {seq length Q, seq length KV, head size, head size V, head number, batch size} = {" << options.seq_length \ + << ", " << options.seq_length_kv << ", " << options.head_size << ", " << options.head_size_v << ", " << options.head_number\ + << ", " << options.batch_size << "}." << std::endl; + std::cout << std::endl; + std::cout << " " << "Runtime: " << result.runtime_ms << " ms" << std::endl; + std::cout << " " << "GFLOPs: " << result.gflops << std::endl; + + return result; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int kQueriesPerBlock, + int kKeysPerBlock, + int kMaxK +> +int run_attention(Options& options) { + using Attention = AttentionKernel< + cutlass::half_t, // scalar_t + cutlass::arch::Sm80, // ArchTag + true, // Memory is aligned + kQueriesPerBlock, + kKeysPerBlock, + kMaxK, + false, // Supports dropout + false // Supports bias + >; + + // + // Test and profile + // + + TestbedAttention testbed(options); + + Result result = testbed.profile(); + if (!result.passed) { + std::cout << "Profiling CUTLASS attention has failed.\n"; + std::cout << "\nFailed\n"; + return -1; + } + + std::cout << "\nPassed\n"; + return 0; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // + // This example uses mma.sync to directly access Tensor Cores to achieve peak performance. + // + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) { + + // + // This example requires an NVIDIA Ampere-architecture GPU. + // + + std::cout + << "CUTLASS's CUTLASS Attention example requires a GPU of NVIDIA's Ampere Architecture or " + << "later (compute capability 80 or greater).\n"; + + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + if (options.use_mask) { + std::cerr << "--use_mask is not supported at the moment\n"; + return -2; + } + if (options.alignment != 1) { + std::cerr << "--alignment=1 is the only supported value\n"; + return -2; + } + + // Determine kernel configuration based on head size. 
+ // If head size is less than or equal to 64, each block operates over 64 queries and + // 64 keys, and partial results can be stored in the register file. + // If head size is greater than 64, each block operates over 32 queries and 128 keys, + // and partial results are stored in shared memory. + if (options.head_size_v > 64) { + static int const kQueriesPerBlock = 32; + static int const kKeysPerBlock = 128; + if (options.head_size_v <= 128) { + return run_attention(options); + } else { + return run_attention(options); + } + } else { + static constexpr int kMaxK = 64; // <- Decrease to 32/16 if your problem is smaller + static int const kQueriesPerBlock = 64; + static int const kKeysPerBlock = 64; + return run_attention(options); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu b/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu new file mode 100644 index 0000000000..49d8699a64 --- /dev/null +++ b/examples/41_fused_multi_head_attention/fused_multihead_attention_variable_seqlen.cu @@ -0,0 +1,1195 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief CUTLASS Attention Example. + + This workload computes a fused multi head attention that supports variable sequence lengths. + Because it keeps the attention matrix in shared memory, it's both faster and + uses less global memory. + + This is based on `"Self-Attention Does Not Need O(n^2) Memory" `_, + and very similar to `"FlashAttention: Fast and Memory-Efficient Exact Attention with IO-Awareness" `_. 
+
+    Algorithm:
+      In short, we can compute the output incrementally in blocks of size B;
+      we only need to divide the final result by the sum of all coefficients in
+      the softmax (which we compute incrementally), as in the following pseudo-code:
+
+      ```
+      s_prime = torch.zeros([num_queries])
+      O = torch.zeros([num_queries, head_size_v])
+      for i in range(0, K.shape[0], B):
+        si = exp((Q . K[i * B:(i+1) * B].t) * scale)
+        s_prime += si.sum(-1)
+        O += si . V[i * B:(i+1) * B]
+      O = O / s_prime
+      ```
+
+      In practice, and for numerical stability reasons,
+      we also subtract the maximum so far (`mi`) before doing
+      the exponential. When we encounter new keys, the maximum
+      used to compute O so far (`m_prime`) can differ from the
+      current maximum, so we update O before accumulating with
+
+      ```
+      O = O * exp(m_prime - mi)
+      m_prime = mi
+      ```
+
+    Implementation details:
+      - `si` is stored in shared memory between the two back-to-back GEMMs
+      - we keep and accumulate the output
+        directly in registers if we can (`head_size_v <= 128`).
+        Otherwise, we store and accumulate it in global memory (slower)
+      - blocks are parallelized across the batch dimension, the number
+        of heads, and the query sequence size
+
+    Examples:
+
+      # Run an attention example with default setup
+      $ ./examples/41_fused_multi_head_attention/41_fused_multi_head_attention_variable_seqlen
+
+      # Run an attention example with custom setup
+      $ ./examples/41_fused_multi_head_attention/41_fused_multi_head_attention_variable_seqlen --head_number=2 --batch_size=3 --head_size=32 --head_size_v=64 --seq_length=512 --seq_length_kv=1024 --causal=true
+
+    Acknowledgement: Fixed-sequence-length FMHA code was upstreamed by Meta xFormers (https://github.com/facebookresearch/xformers).
+    Using grouped GEMM to handle variable sequence lengths is inspired by an idea originally prototyped by ByteDance Inc.
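+
+    To make the grouped-GEMM mapping concrete, the following is a condensed sketch of how
+    one (Q . K^T) problem and one (P . V) problem are created per (batch, head) pair when
+    sequence lengths vary (a simplification of Options::randomize_problems() in this file;
+    masking and the fixed-length path are omitted, and the helper name is illustrative):
+
+    ```
+    #include <cstdlib>
+    #include <vector>
+    #include "cutlass/gemm_coord.h"
+
+    void build_problem_sizes(int batch_size, int head_number,
+                             int max_seq_length, int max_seq_length_kv,
+                             int head_size, int head_size_v, int alignment,
+                             std::vector<cutlass::gemm::GemmCoord>& problem_sizes0,
+                             std::vector<cutlass::gemm::GemmCoord>& problem_sizes1) {
+      for (int b = 0; b < batch_size; ++b) {
+        // Each batch entry draws its own sequence lengths and rounds them up to the alignment.
+        int m   = ((std::rand() % max_seq_length)    + 1 + alignment - 1) / alignment * alignment;
+        int mkv = ((std::rand() % max_seq_length_kv) + 1 + alignment - 1) / alignment * alignment;
+        // All heads of a batch entry share the same sequence lengths.
+        for (int h = 0; h < head_number; ++h) {
+          problem_sizes0.emplace_back(m, mkv, head_size);    // P = Q . K^T : (m x mkv) x k0
+          problem_sizes1.emplace_back(m, head_size_v, mkv);  // O = P . V   : (m x k1)  x mkv
+        }
+      }
+    }
+    ```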
+*/ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/device/gemm_grouped.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm_complex.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/kernel/default_gemm_complex.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/gemm.h" + +#include "cutlass/fast_math.h" + +#include "default_fmha_grouped.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Result structure +struct Result { + + double runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + // + // Methods + // + + Result( + double runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess + ): + runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + bool error; + bool reference_check; + bool use_mask; + bool causal; + bool fixed_seq_length; + + std::vector problem_sizes0; + std::vector problem_sizes1; + + std::vector problem_sizes0_real; + std::vector problem_sizes1_real; + + int alignment; + int head_number; + int batch_size; + int head_size; + int head_size_v; + int seq_length; + int seq_length_kv; + int iterations; + int problem_count; + + // alpha0, alpha1 and beta are fixed + // in this multi-head attention example + float alpha0; + float alpha1; + float beta; + + cutlass::gemm::kernel::GroupScheduleMode scheduler_mode; + + // + // Methods + // + + Options(): + help(false), + error(false), + alignment(1), + reference_check(true), + head_number(12), + batch_size(16), + head_size(64), + head_size_v(64), + seq_length(1024), + seq_length_kv(1024), + use_mask(false), + iterations(20), + causal(false), + fixed_seq_length(false), + problem_count(batch_size * head_number), + scheduler_mode(cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("alignment", alignment, 1); + cmd.get_cmd_line_argument("head_number", head_number, 12); + cmd.get_cmd_line_argument("batch_size", batch_size, 16); + cmd.get_cmd_line_argument("head_size", head_size, 64); + cmd.get_cmd_line_argument("head_size_v", head_size_v, head_size); + cmd.get_cmd_line_argument("seq_length", seq_length, 1024); + cmd.get_cmd_line_argument("seq_length_kv", seq_length_kv, seq_length); + cmd.get_cmd_line_argument("use_mask", use_mask, false); + cmd.get_cmd_line_argument("iterations", iterations, 20); + 
cmd.get_cmd_line_argument("reference-check", reference_check, true); + cmd.get_cmd_line_argument("causal", causal, true); + cmd.get_cmd_line_argument("fixed_seq_length", fixed_seq_length, false); + + std::vector scheduler_mode_strs; + cmd.get_cmd_line_arguments("scheduler-mode", scheduler_mode_strs); + + if (!scheduler_mode_strs.empty()) { + if (scheduler_mode_strs.size() > 1) { + std::cerr << "Only one scheduler mode may be passed in" << std::endl; + error = true; + return; + } + std::string scheduler_mode_str = scheduler_mode_strs[0]; + if (scheduler_mode_str == "kDeviceOnly") { + scheduler_mode = cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly; + } else if (scheduler_mode_str == "kHostPrecompute") { + scheduler_mode = cutlass::gemm::kernel::GroupScheduleMode::kHostPrecompute; + } else { + std::cerr << "Unrecognized scheduler mode '" << scheduler_mode_str << "'" << std::endl; + error = true; + return; + } + } + + if (fixed_seq_length) { + std::cout << "NOTE: Better performance is expected for fixed-sized sequence length from 41_fused_multi_head_attention_fixed_seqlen." << std::endl; + } + + randomize_problems(); + } + + void randomize_problems() { + + problem_count = head_number * batch_size; + + problem_sizes0.reserve(problem_count); + problem_sizes1.reserve(problem_count); + + // When using mask, the original inputs are not padded + // and we need to save these info. + if (use_mask) { + problem_sizes0_real.reserve(problem_count); + problem_sizes1_real.reserve(problem_count); + } + + for (int i = 0; i < batch_size; ++i) { + // problems belonging to the same batch share the same seq len + + int m_real, mkv_real; + if (fixed_seq_length) { + m_real = seq_length; + mkv_real = seq_length_kv; + } else { + m_real = (rand() % seq_length) + 1; + + // Only randomize seq_length_kv if it was set to a different value than + // seq_length originally. + if (seq_length != seq_length_kv) { + mkv_real = (rand() % seq_length_kv) + 1; + } else { + mkv_real = m_real; + } + } + + int m = (m_real + alignment - 1) / alignment * alignment; + int mkv = (mkv_real + alignment - 1) / alignment * alignment; + int k0 = head_size; + int k1 = head_size_v; + + for (int j = 0; j < head_number; ++j) { + cutlass::gemm::GemmCoord problem0(m, mkv, k0); + cutlass::gemm::GemmCoord problem1(m, k1, mkv); + + problem_sizes0.push_back(problem0); + problem_sizes1.push_back(problem1); + + if (use_mask) { + cutlass::gemm::GemmCoord problem0_real(m_real, mkv_real, k0); + cutlass::gemm::GemmCoord problem1_real(m_real, k1, mkv_real); + problem_sizes0_real.push_back(problem0_real); + problem_sizes1_real.push_back(problem1_real); + } + + } + } + } + + void print_problems() { + std::cout << " Running " << batch_size << " batches, each with " << head_number << " heads of size " << head_size << ":" << std::endl; + for (int i = 0; i < batch_size; ++i) { + int idx = i * head_number; + std::cout << " [" << i << "] seq_length = " << problem_sizes0[idx].m() << " seq_length_kv = " << problem_sizes0[idx].n() << std::endl; + } + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "41_fused_multi_head_attention_variable_seqlen\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --head_number= Head number in multi-head attention (default: --head_number=12)\n" + << " --batch_size= Batch size in multi-head attention (default: --batch_size=16)\n" + << " --head_size= Head size in multi-head attention (default: --head_size=64)\n" + << " --head_size_v= Head size in multi-head attention for V (default: --head_size_v=head_size)\n" + << " --seq_length= Sequence length in multi-head attention for Q (default: --seq_length=1024)\n" + << " --seq_length_kv= Sequence length in multi-head attention for K/V (default: --seq_length_kv=seq_length)\n" + << " --use_mask= If true, performs padding-like masking in softmax.\n" + << " --iterations= Number of profiling iterations to perform.\n" + << " --reference-check= If true, performs reference check.\n" + << " --causal= If true, uses causal masking.\n" + << " --fixed_seq_length= If true, uses the same sequence length for each item in the batch.\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + + // Number of real-valued multiply-adds + int64_t fops = int64_t(); + + for (size_t i = 0; i < problem_sizes0.size(); ++i) { + auto const& problem0 = problem_sizes0[i]; + auto const& problem1 = problem_sizes1[i]; + + for (int row = 0; row < problem0.m(); ++row) { + int num_cols0 = problem0.n(); + if (causal) { + num_cols0 = std::min(row + 1, num_cols0); + } + // P <- Q . K_t + fops += 2 * num_cols0 * problem0.k(); + // P <- exp(P - max(P)) + fops += 2 * num_cols0; + // S <- sum(P) + fops += num_cols0 - 1; + // O <- P . V + fops += 2 * num_cols0 * problem1.n(); + // O <- O / S + fops += num_cols0 * problem1.n(); + } + } + + return double(fops) / double(1.0e9) / runtime_s; + } +}; + + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class TestbedAttention { +public: + + // + // Type definitions + // + + using scalar_t = typename Attention::GemmKernel::scalar_t; + using accum_t = typename Attention::GemmKernel::accum_t; + using output_t = typename Attention::GemmKernel::output_t; + using output_accum_t = typename Attention::GemmKernel::output_accum_t; + + using ElementQ = scalar_t; + using ElementK = scalar_t; + using ElementP = accum_t; + using ElementAccumulator = accum_t; + using ElementV = scalar_t; + using ElementO = output_t; + using ElementOAccum = output_accum_t; + + using ElementCompute = accum_t; + + using ElementNorm = accum_t; + using ElementSum = accum_t; + using ElementSoftmaxCompute = accum_t; + + using LayoutQ = cutlass::layout::RowMajor; + using LayoutK = cutlass::layout::ColumnMajor; + using LayoutP = cutlass::layout::RowMajor; + using LayoutV = cutlass::layout::RowMajor; + using LayoutO = cutlass::layout::RowMajor; + + using MatrixCoord = typename LayoutP::TensorCoord; + + static bool const kNeedsOutputAccumulatorBuffer = Attention::GemmKernel::kNeedsOutputAccumulatorBuffer; + +private: + + // + // Data members + // + + Options & options; + + /// Initialization + cutlass::Distribution::Kind init_Q; + cutlass::Distribution::Kind init_K; + cutlass::Distribution::Kind init_P; + cutlass::Distribution::Kind init_V; + cutlass::Distribution::Kind init_O; + uint32_t seed; + + cutlass::DeviceAllocation problem_sizes_device0; + cutlass::DeviceAllocation problem_sizes_device1; + cutlass::DeviceAllocation 
problem_sizes_device0_real; + + std::vector offset_Q; + std::vector offset_K; + std::vector offset_P; + std::vector offset_V; + std::vector offset_O; + + std::vector ldq_host; + std::vector ldk_host; + std::vector ldp_host; + std::vector ldv_host; + std::vector ldo_host; + std::vector seqlen_host; + + cutlass::DeviceAllocation ldq; + cutlass::DeviceAllocation ldk; + cutlass::DeviceAllocation ldp; + cutlass::DeviceAllocation ldv; + cutlass::DeviceAllocation ldo; + cutlass::DeviceAllocation seqlen; + + cutlass::DeviceAllocation block_Q; + cutlass::DeviceAllocation block_K; + cutlass::DeviceAllocation block_P; + cutlass::DeviceAllocation block_V; + cutlass::DeviceAllocation block_O; + cutlass::DeviceAllocation block_O_accumulate; + cutlass::DeviceAllocation block_Norm; + cutlass::DeviceAllocation block_Sum; + + cutlass::DeviceAllocation offset_P_Device; + + cutlass::DeviceAllocation ptr_Q; + cutlass::DeviceAllocation ptr_K; + cutlass::DeviceAllocation ptr_P; + cutlass::DeviceAllocation ptr_V; + cutlass::DeviceAllocation ptr_O; + cutlass::DeviceAllocation ptr_O_accumulate; + + +public: + + // + // Methods + // + + TestbedAttention( + Options &options_, + cutlass::Distribution::Kind init_Q_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_K_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_P_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_V_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_O_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3080 + ): + options(options_), init_Q(init_Q_), init_K(init_K_), init_P(init_P_), init_V(init_V_), init_O(init_O_), seed(seed_) { } + + int problem_count() const { + return (options.head_number * options.batch_size); + } + +private: + + /// Helper to initialize a tensor view + template + void initialize_tensor_( + Element *ptr, + size_t capacity, + cutlass::Distribution::Kind dist_kind, + uint32_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 8; + scope_min = -8; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::device::BlockFillRandomUniform( + ptr, capacity, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::device::BlockFillRandomGaussian( + ptr, capacity, seed, Element(), Element(0.5f)); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + // Fill with increasing elements + cutlass::reference::device::BlockFillSequential( + ptr, capacity, Element(1), Element()); + } + else { + + // Fill with all 1s + cutlass::reference::device::BlockFillSequential( + ptr, capacity, Element(), Element(1)); + } + } + + /// Initializes data structures + void initialize_() { + + // + // Set scalors for the mha example + // + + options.alpha0 = 1.0f / sqrt(float(options.head_size)); + options.alpha1 = 1.0f; + options.beta = 0; + + // + // Choose random problem sizes + // + + // construct a few problems of random sizes + srand(seed); + + int64_t total_elements_Q = 0; + int64_t total_elements_K = 0; + int64_t total_elements_P = 0; + int64_t total_elements_V = 0; + int64_t total_elements_O = 0; + + ldq_host.resize(problem_count()); + ldk_host.resize(problem_count()); + 
ldp_host.resize(problem_count()); + ldv_host.resize(problem_count()); + ldo_host.resize(problem_count()); + seqlen_host.resize(problem_count()); + + for (int32_t i = 0; i < problem_count(); ++i) { + + auto problem0 = options.problem_sizes0.at(i); + auto problem1 = options.problem_sizes1.at(i); + + ldq_host.at(i) = LayoutQ::packed({problem0.m(), problem0.k()}).stride(0); + ldk_host.at(i) = LayoutK::packed({problem0.k(), problem0.n()}).stride(0); + ldp_host.at(i) = LayoutP::packed({problem0.m(), problem0.n()}).stride(0); + ldv_host.at(i) = LayoutV::packed({problem1.k(), problem1.n()}).stride(0); + ldo_host.at(i) = LayoutO::packed({problem1.m(), problem1.n()}).stride(0); + + // m = n for attention problems. + seqlen_host.at(i) = problem0.m(); + + offset_Q.push_back(total_elements_Q); + offset_K.push_back(total_elements_K); + offset_P.push_back(total_elements_P); + offset_V.push_back(total_elements_V); + offset_O.push_back(total_elements_O); + + int64_t elements_Q = problem0.m() * problem0.k(); + int64_t elements_K = problem0.k() * problem0.n(); + int64_t elements_P = problem0.m() * problem0.n(); + int64_t elements_V = problem1.k() * problem1.n(); + int64_t elements_O = problem1.m() * problem1.n(); + + total_elements_Q += elements_Q; + total_elements_K += elements_K; + total_elements_P += elements_P; + total_elements_V += elements_V; + total_elements_O += elements_O; + + } + + problem_sizes_device0.reset(problem_count()); + problem_sizes_device1.reset(problem_count()); + problem_sizes_device0.copy_from_host(options.problem_sizes0.data()); + problem_sizes_device1.copy_from_host(options.problem_sizes1.data()); + + if (options.use_mask) { + problem_sizes_device0_real.reset(problem_count()); + problem_sizes_device0_real.copy_from_host(options.problem_sizes0_real.data()); + } + + ldq.reset(problem_count()); + ldk.reset(problem_count()); + ldp.reset(problem_count()); + ldv.reset(problem_count()); + ldo.reset(problem_count()); + seqlen.reset(problem_count()); + + ldq.copy_from_host(ldq_host.data()); + ldk.copy_from_host(ldk_host.data()); + ldp.copy_from_host(ldp_host.data()); + ldv.copy_from_host(ldv_host.data()); + ldo.copy_from_host(ldo_host.data()); + seqlen.copy_from_host(seqlen_host.data()); + + // + // Assign pointers + // + + block_Q.reset(total_elements_Q); + block_K.reset(total_elements_K); + block_P.reset(total_elements_P); + block_V.reset(total_elements_V); + block_O.reset(total_elements_O); + + if (kNeedsOutputAccumulatorBuffer) { + block_O_accumulate.reset(total_elements_O); + } + + offset_P_Device.reset(problem_count()); + + // sync offset with device + cutlass::device_memory::copy_to_device(offset_P_Device.get(), offset_P.data(), offset_P.size()); + + std::vector ptr_Q_host(problem_count()); + std::vector ptr_K_host(problem_count()); + std::vector ptr_P_host(problem_count()); + std::vector ptr_V_host(problem_count()); + std::vector ptr_O_host(problem_count()); + std::vector ptr_O_accumulate_host(problem_count()); + std::vector ptr_norm_host(problem_count()); + std::vector ptr_sum_host(problem_count()); + + for (int32_t i = 0; i < problem_count(); ++i) { + ptr_Q_host.at(i) = block_Q.get() + offset_Q.at(i); + ptr_K_host.at(i) = block_K.get() + offset_K.at(i); + ptr_P_host.at(i) = block_P.get() + offset_P.at(i); + ptr_V_host.at(i) = block_V.get() + offset_V.at(i); + ptr_O_host.at(i) = block_O.get() + offset_O.at(i); + + if (kNeedsOutputAccumulatorBuffer) { + ptr_O_accumulate_host.at(i) = block_O_accumulate.get() + offset_O.at(i); + } + } + + ptr_Q.reset(problem_count()); + 
ptr_Q.copy_from_host(ptr_Q_host.data()); + + ptr_K.reset(problem_count()); + ptr_K.copy_from_host(ptr_K_host.data()); + + ptr_P.reset(problem_count()); + ptr_P.copy_from_host(ptr_P_host.data()); + + ptr_V.reset(problem_count()); + ptr_V.copy_from_host(ptr_V_host.data()); + + ptr_O.reset(problem_count()); + ptr_O.copy_from_host(ptr_O_host.data()); + + if (kNeedsOutputAccumulatorBuffer) { + ptr_O_accumulate.reset(problem_count()); + ptr_O_accumulate.copy_from_host(ptr_O_accumulate_host.data()); + } + + // + // Initialize the problems of the workspace + // + + initialize_tensor_(block_Q.get(), total_elements_Q, init_Q, seed + 1); + initialize_tensor_(block_K.get(), total_elements_K, init_K, seed + 2); + initialize_tensor_(block_V.get(), total_elements_V, init_V, seed + 3); + + } + + template + bool verify_tensor_(std::vector vector_Input, \ + std::vector vector_Input_Ref, + int64_t verify_length = -1) { + + int64_t size = (vector_Input.size() < vector_Input_Ref.size()) ? vector_Input.size() : vector_Input_Ref.size(); + size = (verify_length == -1) ? size : verify_length; + + // 0.05 for absolute error + float abs_tol = 5e-2f; + // 10% for relative error + float rel_tol = 1e-1f; + for (int64_t i = 0; i < size; ++i) { + float diff = (float)(vector_Input.at(i) - vector_Input_Ref.at(i)); + float abs_diff = fabs(diff); + float abs_ref = fabs((float)vector_Input_Ref.at(i) + 1e-5f); + float relative_diff = abs_diff / abs_ref; + if ( (isnan(abs_diff) || isinf(abs_diff)) || (abs_diff > abs_tol && relative_diff > rel_tol)) { + printf("[%d/%d] diff = %f, rel_diff = %f, {computed=%f, ref=%f}.\n", int(i), int(size), abs_diff, relative_diff, (float)(vector_Input.at(i)), (float)(vector_Input_Ref.at(i))); + return false; + } + + } + + return true; + } + + /// Verifies the result is a GEMM + bool verify_() { + + bool passed = true; + + for (int32_t i = 0; i < problem_count(); ++i) { + cutlass::gemm::GemmCoord problem0 = options.problem_sizes0.at(i); + cutlass::gemm::GemmCoord problem1 = options.problem_sizes1.at(i); + + LayoutQ layout_Q(ldq_host.at(i)); + LayoutK layout_K(ldk_host.at(i)); + LayoutP layout_P(ldp_host.at(i)); + LayoutV layout_V(ldv_host.at(i)); + LayoutO layout_O(ldo_host.at(i)); + + MatrixCoord extent_Q{problem0.m(), problem0.k()}; + MatrixCoord extent_K{problem0.k(), problem0.n()}; + MatrixCoord extent_P{problem0.m(), problem0.n()}; + MatrixCoord extent_V{problem1.k(), problem1.n()}; + MatrixCoord extent_O{problem1.m(), problem1.n()}; + + cutlass::TensorView view_Q(block_Q.get() + offset_Q.at(i), layout_Q, extent_Q); + cutlass::TensorView view_K(block_K.get() + offset_K.at(i), layout_K, extent_K); + cutlass::TensorView view_P(block_P.get() + offset_P.at(i), layout_P, extent_P); + cutlass::TensorView view_V(block_V.get() + offset_V.at(i), layout_V, extent_V); + + cutlass::DeviceAllocation block_Ref(layout_P.capacity(extent_P)); + cutlass::TensorView view_Ref_device(block_Ref.get(), layout_P, extent_P); + + cutlass::DeviceAllocation block_Ref_O(layout_O.capacity(extent_O)); + cutlass::TensorView view_Ref_O_device(block_Ref_O.get(), layout_O, extent_O); + cutlass::reference::device::TensorFill(view_Ref_O_device, ElementO(0)); + + // Reference GEMM + cutlass::reference::device::GemmComplex< + ElementQ, LayoutQ, + ElementK, LayoutK, + ElementP, LayoutP, + ElementCompute, ElementAccumulator + >( + problem0, + ElementAccumulator(options.alpha0), + view_Q, + Attention::GemmKernel::MM0::Mma::kTransformA, + view_K, + Attention::GemmKernel::MM0::Mma::kTransformB, + ElementAccumulator(options.beta), + 
view_P, + view_Ref_device, + ElementAccumulator(0) + ); + + // Compute softmax for P. We need to explicitly compute softmax + // over P because softmax is fused to the second GEMM in the + // profiled implementation. + std::vector matrix_Ref(layout_P.capacity(extent_P)); + cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref.get(), matrix_Ref.size()); + cutlass::TensorView view_Ref_host(matrix_Ref.data(), layout_P, extent_P); + std::vector vector_Norm_Ref(problem0.m()); + std::vector vector_Sum_Ref(problem0.m()); + + int n_dim = options.use_mask ? options.problem_sizes0_real.at(i).n() : problem0.n(); + + // Compute softmax for reference matrix + for (int m = 0; m < problem0.m(); m++) { + int n_dim_row = n_dim; + if (options.causal) { + n_dim_row = std::min(m + 1, n_dim); + } + ElementSoftmaxCompute max = ElementSoftmaxCompute(view_Ref_host.ref().at({m, 0})); + for (int n = 1; n < n_dim_row; n++) { + max = std::max(max, ElementSoftmaxCompute(view_Ref_host.ref().at({m, n}))); + } + + vector_Norm_Ref.at(m) = ElementNorm(max); + + ElementSoftmaxCompute sum = ElementSoftmaxCompute(); + for (int n = 0; n < n_dim_row; n++) { + sum += std::exp( ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})) - max ); + } + ElementSoftmaxCompute inv_sum = ElementSoftmaxCompute(1.0f / sum); + + vector_Sum_Ref.at(m) = ElementSum(inv_sum); + + for (int n = 0; n < n_dim_row; n++) { + view_Ref_host.ref().at({m, n}) = ElementP( + std::exp( ElementSoftmaxCompute(view_Ref_host.ref().at({m, n})) - max ) * inv_sum + ); + } + // Mask out the rest of the attention matrix + for (int n = n_dim_row; n < n_dim; ++n) { + view_Ref_host.ref().at({m, n}) = ElementP(0); + } + + } + + // when not using mask, problem_real and problem share the same sizes + if (options.use_mask) { + for (int m = 0; m < problem0.m(); m++) { + for (int n = n_dim; n < problem0.n(); n++) { + view_Ref_host.ref().at({m, n}) = ElementP(0); + } + } + } + + cutlass::device_memory::copy_to_device(block_P.get() + offset_P.at(i), matrix_Ref.data(), matrix_Ref.size()); + + // Reference GEMM + cutlass::reference::device::GemmComplex< + ElementP, LayoutP, + ElementV, LayoutV, + ElementO, LayoutO, + ElementCompute, ElementAccumulator + >( + problem1, + ElementAccumulator(options.alpha1), + view_P, + Attention::GemmKernel::MM0::Mma::kTransformA, + view_V, + Attention::GemmKernel::MM0::Mma::kTransformB, + ElementAccumulator(options.beta), + view_Ref_O_device, + view_Ref_O_device, + ElementAccumulator(0) + ); + + // Copy to host memory + cutlass::TensorView view_Ref(matrix_Ref.data(), layout_P, extent_P); + + std::vector matrix_O(layout_O.capacity(extent_O)); + cutlass::device_memory::copy_to_host(matrix_O.data(), block_O.get() + offset_O.at(i), matrix_O.size()); + std::vector matrix_Ref_O(layout_O.capacity(extent_O)); + cutlass::device_memory::copy_to_host(matrix_Ref_O.data(), block_Ref_O.get(), matrix_Ref_O.size()); + + + bool verified_O = false; + if (!verified_O) { + verified_O = verify_tensor_(matrix_O, matrix_Ref_O); + } + + passed = passed && verified_O; + + if (!passed) { + std::cerr << "\n***\nError - problem " << i << " failed the QA check\n***\n" << std::endl; + + if (!verified_O) { + std::cout << "Final matrix output is incorrect" << std::endl; + } + + return passed; + } + + } + + return passed; + } + +public: + + + /// Executes a CUTLASS Attention kernel and measures runtime. 
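+  /// Initializes the grouped problem, runs the kernel once (verifying it against
+  /// a reference GEMM + softmax + GEMM when options.reference_check is set), and
+  /// then times options.iterations back-to-back launches with CUDA events to
+  /// report the average runtime and achieved GFLOPs.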
+ Result profile() { + + Result result; + result.passed = false; + + int threadblock_count = Attention::sufficient(options.problem_sizes1.data(), options.problem_count); + + // Early exit + if (!threadblock_count) { + std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Grouped FMHA kernel." << std::endl; + return result; + } + + result.passed = false; + + // Initialize the problem + initialize_(); + + typename Attention::Arguments args( + problem_sizes_device0.get(), + problem_sizes_device1.get(), + options.problem_count, + threadblock_count, + ptr_Q.get(), + ptr_K.get(), + ptr_P.get(), + ptr_V.get(), + ptr_O.get(), + ptr_O_accumulate.get(), + ldq.get(), + ldk.get(), + ldp.get(), + ldv.get(), + ldo.get(), + options.causal, + options.alpha0, + options.problem_sizes1.data() + ); + + Attention fmha; + + size_t workspace_size = fmha.get_workspace_size(args); + cutlass::DeviceAllocation workspace(workspace_size); + + result.status = fmha.initialize(args, workspace.get()); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize CUTLASS Grouped FMHA kernel." << std::endl; + return result; + } + + // Run the grouped FMHA object + result.status = fmha.run(); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run CUTLASS Grouped FMHA kernel." << std::endl; + return result; + } + + // Wait for completion + result.error = cudaDeviceSynchronize(); + + if (result.error != cudaSuccess) { + std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // + // Verify correctness + // + result.passed = true; + + if (options.reference_check) { + result.passed = verify_(); + } + + // + // Warm-up run of the grouped FMHA object + // + result.status = fmha.run(); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run CUTLASS Grouped FMHA kernel." << std::endl; + return result; + } + + // + // Construct events + // + + cudaEvent_t events[2]; + + for (auto & event : events) { + result.error = cudaEventCreate(&event); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; + return -1; + } + } + + // Record an event at the start of a series of FMHA operations + result.error = cudaEventRecord(events[0]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // + // Run profiling loop + // + + for (int iter = 0; iter < this->options.iterations; ++iter) { + fmha(); + } + + // + // Stop profiling loop + // + + // Record an event when the GEMM operations have been launched. + result.error = cudaEventRecord(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Wait for work on the device to complete. + result.error = cudaEventSynchronize(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Compute average runtime and GFLOPs. 
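+    // The elapsed time spans options.iterations launches, so dividing by the
+    // iteration count gives the per-launch average; options.gflops() converts
+    // that average (in seconds) into achieved GFLOPs.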
+ result.runtime_ms = double(runtime_ms) / double(this->options.iterations); + result.gflops = this->options.gflops(result.runtime_ms / 1000.0); + + // + // Cleanup + // + + for (auto event : events) { + (void)cudaEventDestroy(event); + } + + std::cout << std::endl; + std::cout << "CUTLASS Attention:\n" + << "====================================================" << std::endl; + std::cout << " " << " {seq length Q, seq length KV, head size, head size V, head number, batch size} = {" << options.seq_length \ + << ", " << options.seq_length_kv << ", " << options.head_size << ", " << options.head_size_v << ", " << options.head_number\ + << ", " << options.batch_size << "}." << std::endl; + options.print_problems(); + std::cout << std::endl; + std::cout << " " << "Runtime: " << result.runtime_ms << " ms" << std::endl; + std::cout << " " << "GFLOPs: " << result.gflops << std::endl; + + return result; + } + + +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int kQueriesPerBlock, + int kKeysPerBlock, + int kMaxK, + cutlass::gemm::kernel::GroupScheduleMode GroupScheduleMode_ +> +int run_grouped(Options& options) { + using AttentionKernel = typename cutlass::gemm::kernel::DefaultFMHAGrouped< + cutlass::half_t, // scalar_t + cutlass::arch::Sm80, // ArchTag + true, // Memory is aligned + kQueriesPerBlock, + kKeysPerBlock, + kMaxK, + GroupScheduleMode_ + >::FMHAKernel; + + using FMHA = cutlass::gemm::device::GemmGrouped; + + // + // Test and profile + // + + TestbedAttention testbed(options); + + Result result = testbed.profile(); + if (!result.passed) { + std::cout << "Profiling CUTLASS attention has failed.\n"; + std::cout << "\nFailed\n"; + return -1; + } + + std::cout << "\nPassed\n"; + return 0; +} + + +template < + int kQueriesPerBlock, + int kKeysPerBlock, + int kMaxK +> +int run_attention(Options& options) { + if (options.scheduler_mode == cutlass::gemm::kernel::GroupScheduleMode::kDeviceOnly) { + return run_grouped(options); + } else { + return run_grouped(options); + } +} + + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // + // This example uses mma.sync to directly access Tensor Cores to achieve peak performance. + // + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) { + + // + // This example requires an NVIDIA Ampere-architecture GPU. + // + + std::cout + << "CUTLASS's CUTLASS Attention example requires a GPU of NVIDIA's Ampere Architecture or " + << "later (compute capability 80 or greater).\n"; + + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + if (options.use_mask) { + std::cerr << "--use_mask is not supported at the moment\n"; + return -2; + } + if (options.alignment != 1) { + std::cerr << "--alignment=1 is the only supported value\n"; + return -2; + } + + // Determine kernel configuration based on head size. 
+ // If head size is less than or equal to 64, each block operates over 64 queries and + // 64 keys, and partial results can be stored in the register file. + // If head size is greater than 64, each block operates over 32 queries and 128 keys, + // and partial results are stored in shared memory. + if (options.head_size_v > 64) { + static int const kQueriesPerBlock = 32; + static int const kKeysPerBlock = 128; + if (options.head_size_v <= kKeysPerBlock) { + return run_attention(options); + } else { + return run_attention(options); + } + } else { + static constexpr int kMaxK = 64; // <- Decrease to 32/16 if your problem is smaller + static int const kQueriesPerBlock = 64; + static int const kKeysPerBlock = 64; + return run_attention(options); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/41_fused_multi_head_attention/gemm/custom_mma.h b/examples/41_fused_multi_head_attention/gemm/custom_mma.h new file mode 100644 index 0000000000..80f5d4ea11 --- /dev/null +++ b/examples/41_fused_multi_head_attention/gemm/custom_mma.h @@ -0,0 +1,124 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "custom_mma_multistage.h" +#include "custom_mma_pipelined.h" +#include "cutlass/gemm/threadblock/mma_multistage.h" +#include "cutlass/gemm/threadblock/mma_pipelined.h" + +template +struct MakeCustomMma; + +template < + typename Shape, + typename IteratorA, + typename SmemIteratorA, + cutlass::arch::CacheOperation::Kind CacheOpA, + typename IteratorB, + typename SmemIteratorB, + cutlass::arch::CacheOperation::Kind CacheOpB, + typename ElementC, + typename LayoutC, + typename Policy, + int Stages, + cutlass::gemm::SharedMemoryClearOption SharedMemoryClear, + int kMaxK> +struct MakeCustomMma< + cutlass::gemm::threadblock::MmaMultistage< + Shape, + IteratorA, + SmemIteratorA, + CacheOpA, + IteratorB, + SmemIteratorB, + CacheOpB, + ElementC, + LayoutC, + Policy, + Stages, + SharedMemoryClear>, + kMaxK> { + // Reduce the number of stages if we don't need that many + static int constexpr kStages = + kMaxK == cutlass::platform::numeric_limits::max() + ? Stages + : cutlass::const_min( + Stages, + (kMaxK + int(Shape::kK) - 1) / int(Shape::kK)); + using Mma = cutlass::gemm::threadblock::CustomMmaMultistage< + Shape, + IteratorA, + SmemIteratorA, + CacheOpA, + IteratorB, + SmemIteratorB, + CacheOpB, + ElementC, + LayoutC, + Policy, + kStages, + SharedMemoryClear, + kMaxK>; +}; + +template < + typename Shape, + typename IteratorA, + typename SmemIteratorA, + typename IteratorB, + typename SmemIteratorB, + typename ElementC, + typename LayoutC, + typename Policy, + int kMaxK> +struct MakeCustomMma< + cutlass::gemm::threadblock::MmaPipelined< + Shape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + Policy>, + kMaxK> { + using Mma = cutlass::gemm::threadblock::CustomMmaPipelined< + Shape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + Policy>; +}; diff --git a/examples/41_fused_multi_head_attention/gemm/custom_mma_base.h b/examples/41_fused_multi_head_attention/gemm/custom_mma_base.h new file mode 100644 index 0000000000..be25f79c4e --- /dev/null +++ b/examples/41_fused_multi_head_attention/gemm/custom_mma_base.h @@ -0,0 +1,182 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class CustomMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. 
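+  /// (This is Policy::Operator::Shape, i.e. the M-by-N-by-K tile computed by a
+  /// single warp.)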
+ using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape< + Shape::kM / WarpGemm::kM, + Shape::kN / WarpGemm::kN, + Shape::kK / WarpGemm::kK>; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + template + struct OperandSharedStorage { + AlignedBuffer buffer; + using TensorRef = TensorRef; + + CUTLASS_DEVICE + static OperandLayout Layout() { + return OperandLayout::packed({OperandShape::kRow, OperandShape::kColumn}); + } + + /// Returns a TensorRef to the operand + CUTLASS_HOST_DEVICE + TensorRef ref() { + return TensorRef{buffer.data(), Layout()}; + } + }; + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape< + Shape::kM + Policy::SmemPaddingA::kRow, + Shape::kK * kStages + Policy::SmemPaddingA::kColumn>; + + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape< + Shape::kK * kStages + Policy::SmemPaddingB::kRow, + Shape::kN + Policy::SmemPaddingB::kColumn>; + + using SharedStorageA = OperandSharedStorage< + typename Operator::ElementA, + ShapeA, + typename Operator::LayoutA>; + using SharedStorageB = OperandSharedStorage< + typename Operator::ElementB, + ShapeB, + typename Operator::LayoutB>; + using TensorRefA = typename SharedStorageA::TensorRef; + using TensorRefB = typename SharedStorageB::TensorRef; + + struct SharedStorage { + /// Buffer for A operand + SharedStorageA operand_A; + + /// Buffer for B operand + SharedStorageB operand_B; + }; + + protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + CustomMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorageA& shared_storageA, + SharedStorageB& shared_storageB, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storageA.ref(), lane_idx), + warp_tile_iterator_B_(shared_storageB.ref(), lane_idx) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/41_fused_multi_head_attention/gemm/custom_mma_multistage.h b/examples/41_fused_multi_head_attention/gemm/custom_mma_multistage.h new file mode 100644 index 0000000000..eedcb6376b --- /dev/null +++ b/examples/41_fused_multi_head_attention/gemm/custom_mma_multistage.h @@ -0,0 +1,760 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/cache_operation.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "custom_mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
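+/// This variant differs from cutlass::gemm::threadblock::MmaMultistage in that
+/// the global->shared prologue can be issued externally and skipped via
+/// set_prologue_done(), and that, when kMaxK fits within Shape::kK * Stages,
+/// the entire K extent stays resident in shared memory (kSmemContainsEntireMat)
+/// rather than being streamed through a circular buffer.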
+template <
+    /// Size of the Gemm problem - concept: gemm::GemmShape<>
+    typename Shape_,
+    /// Iterates over tiles of A operand in global memory
+    //  (concept: ReadableTileIterator | ForwardTileIterator |
+    //  MaskedTileIterator)
+    typename IteratorA_,
+    /// Iterates over tiles of A operand in shared memory
+    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
+    typename SmemIteratorA_,
+    /// Cache operation for operand A
+    cutlass::arch::CacheOperation::Kind CacheOpA,
+    /// Iterates over tiles of B operand in global memory
+    //  (concept: ReadableTileIterator | ForwardTileIterator |
+    //  MaskedTileIterator)
+    typename IteratorB_,
+    /// Iterates over tiles of B operand in shared memory
+    /// (concept: WriteableTileIterator | RandomAccessTileIterator)
+    typename SmemIteratorB_,
+    /// Cache operation for operand B
+    cutlass::arch::CacheOperation::Kind CacheOpB,
+    /// Data type of accumulator matrix
+    typename ElementC_,
+    /// Layout of accumulator matrix
+    typename LayoutC_,
+    /// Policy describing tuning details (concept: MmaPolicy)
+    typename Policy_,
+    /// Number of stages
+    int Stages,
+    /// Use zfill or predicate for out-of-bound cp.async
+    SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone,
+    /// Upper bound on the K dimension
+    int kMaxK = cutlass::platform::numeric_limits<int>::max(),
+    /// Used for partial specialization
+    typename Enable = bool>
+class CustomMmaMultistage : public CustomMmaBase<Shape_, Policy_, Stages> {
+ public:
+  ///< Base class
+  using Base = CustomMmaBase<Shape_, Policy_, Stages>;
+  ///< Size of the Gemm problem - concept: gemm::GemmShape<>
+  using Shape = Shape_;
+  ///< Iterates over tiles of A operand in global memory
+  using IteratorA = IteratorA_;
+  ///< Iterates over tiles of B operand in global memory
+  using IteratorB = IteratorB_;
+  ///< Data type of accumulator matrix
+  using ElementC = ElementC_;
+  ///< Layout of accumulator matrix
+  using LayoutC = LayoutC_;
+  ///< Policy describing tuning details
+  using Policy = Policy_;
+
+  using SmemIteratorA = SmemIteratorA_;
+  using SmemIteratorB = SmemIteratorB_;
+
+  static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA;
+  static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB;
+
+  //
+  // Dependent types
+  //
+
+  /// Fragment of accumulator tile
+  using FragmentC = typename Policy::Operator::FragmentC;
+
+  /// Warp-level Mma
+  using Operator = typename Policy::Operator;
+
+  /// Minimum architecture is Sm80 to support cp.async
+  using ArchTag = arch::Sm80;
+
+  /// Complex transform on A operand
+  static ComplexTransform const kTransformA = Operator::kTransformA;
+
+  /// Complex transform on B operand
+  static ComplexTransform const kTransformB = Operator::kTransformB;
+
+  /// Internal structure exposed for introspection.
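+  /// Detail derives how many cp.async copies are issued per pipeline stage for
+  /// each operand and how those copies are spread across the warp-level MMA
+  /// iterations.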
+ struct Detail { + static_assert( + Base::kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / + Base::kWarpGemmIterations; + }; + + static bool const kSmemContainsEntireMat = kMaxK <= Shape::kK * Stages; + static constexpr int kNumStagesConcurrentLoad = + kSmemContainsEntireMat ? Stages : Stages - 1; + + private: + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + private: + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + bool prologue_done_; + + // Set to `True` to ensure the accumulator will be zero outside the GEMM + // footprint + bool zero_outside_bounds_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + CustomMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storageA.ref(), thread_idx), + smem_iterator_B_(shared_storageB.ref(), thread_idx), + prologue_done_(false), + zero_outside_bounds_(false) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + CUTLASS_DEVICE + CustomMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& st, + ///< ID within the threadblock + int thread_idx, + ///< ID of 
warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : CustomMmaMultistage( + st.operand_A, + st.operand_B, + thread_idx, + warp_idx, + lane_idx) {} + + CUTLASS_DEVICE + bool set_prologue_done(bool value) { + prologue_done_ = value; + return true; + } + + CUTLASS_DEVICE + bool set_zero_outside_bounds(bool value) { + zero_outside_bounds_ = value; + return true; + } + + template + CUTLASS_DEVICE static void prologue( + typename Base::SharedStorage& shared_storage, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) { + prologue( + shared_storage.operand_A, + shared_storage.operand_B, + iterator_A, + iterator_B, + thread_idx, + problem_size_k); + } + + template + CUTLASS_DEVICE static void prologue( + typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) { + SmemIteratorA smem_iterator_A(shared_storageA.ref(), thread_idx); + SmemIteratorB smem_iterator_B(shared_storageB.ref(), thread_idx); + int32_t iter = (problem_size_k + Base::Shape::kK - 1) / Base::Shape::kK; + _prologue( + iterator_A, iterator_B, iter, smem_iterator_A, smem_iterator_B); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance( + IteratorA& iterator_A, + IteratorB& iterator_B, + int group_start_A = 0, + int group_start_B = 0) { + iterator_A.set_iteration_index( + group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (zero_outside_bounds_ || + SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B.set_iteration_index( + group_start_B * IteratorB::kAccessesPerVector); + this->smem_iterator_B_.set_iteration_index(group_start_B); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B.get(); + + if (zero_outside_bounds_ || + SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B.valid()); + } + + ++iterator_B; + } + 
++this->smem_iterator_B_; + } + } + } + + template + CUTLASS_DEVICE static void _prologue( + IteratorA& iterator_A, + IteratorB& iterator_B, + int32_t& gemm_k_iterations, + SmemIteratorA& smem_iterator_A_, + SmemIteratorB& smem_iterator_B_) { + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < kNumStagesConcurrentLoad; + ++stage, --gemm_k_iterations) { + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + if (kLoadA) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + } + + ++iterator_A; + } + + ++smem_iterator_A_; + } + + iterator_B.set_iteration_index(0); + smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + if (kLoadB) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + } + + ++iterator_B; + } + + ++smem_iterator_B_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + smem_iterator_A_.add_tile_offset({0, 1}); + smem_iterator_B_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC& accum, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + ///< initial value of accumulator + FragmentC const& src_accum) { + // + // Prologue + // + + if (!prologue_done_) { + _prologue( + iterator_A, + iterator_B, + gemm_k_iterations, + smem_iterator_A_, + smem_iterator_B_); + } else if (!kSmemContainsEntireMat) { + _prologue( + iterator_A, + iterator_B, + gemm_k_iterations, + smem_iterator_A_, + smem_iterator_B_); + } else { + gemm_k_iterations -= kNumStagesConcurrentLoad; + } + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for + // some kernels so that all accumulator elements outside the GEMM footprint + // are zero. 
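+    // This clearing pass is taken only when SharedMemoryClear == kClearLastStage.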
+ // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + /// Iterator to write threadblock-scoped tile of A operand to shared + /// memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + /// Iterator to write threadblock-scoped tile of B operand to shared + /// memory + SmemIteratorB last_smem_iterator_B(this->smem_iterator_B_); + typename IteratorB::AccessType zero_B; + + zero_B.clear(); + last_smem_iterator_B.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType* dst_ptr = + reinterpret_cast( + last_smem_iterator_B.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B; + } + } + + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA warp_loaded_frag_A[2]; + WarpLoadedFragmentB warp_loaded_frag_B[2]; + WarpTransformedFragmentA warp_transformed_frag_A[2]; + WarpTransformedFragmentB warp_transformed_frag_B[2]; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma.transform( + warp_transformed_frag_A[0], + warp_transformed_frag_B[0], + warp_loaded_frag_A[0], + warp_loaded_frag_B[0]); + + // tf32x3 kernels use staging accumulation. warp_mma uses a temporary + // accumulator and this temporary accumulator is added to the final + // accumulator once in every mainloop iteration. + plus plus_accum; + + FragmentC tmp_accum; + + if (platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + tmp_accum.clear(); + } + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-kNumStagesConcurrentLoad);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + // In case of a non-circular buffer ("kSmemContainsEntireMat") + // make sure we don't load out of bounds data. 
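+        // i.e. skip the final wrap-around fragment prefetch once the last K group
+        // has been consumed and the whole matrix is already resident in shared
+        // memory.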
+ if (!kSmemContainsEntireMat || + gemm_k_iterations > (-kNumStagesConcurrentLoad) || + warp_mma_k < Base::kWarpGemmIterations - 1) { + this->warp_tile_iterator_A_.load( + warp_loaded_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load( + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + } + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k > 0) + warp_mma.transform( + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B[warp_mma_k % 2]); + + if (platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + warp_mma( + tmp_accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + tmp_accum); + + if (warp_mma_k == 0) { + accum = plus_accum(accum, tmp_accum); + tmp_accum.clear(); + } + } else { + warp_mma( + accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + accum); + } + + // Issue global->shared copies for the this stage + if (!kSmemContainsEntireMat && + warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance( + iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + if (!kSmemContainsEntireMat) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance( + iterator_A, + iterator_B, + group_start_iteration_A, + group_start_iteration_B); + } + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. 
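+          // Together with the following __syncthreads(), this guarantees that the
+          // stage about to be read has fully arrived in shared memory.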
+ cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (!kSmemContainsEntireMat && + smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B.clear_mask(gemm_k_iterations == 0); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations) + warp_mma.transform( + warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + } + } + + if (platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + accum = plus_accum(accum, tmp_accum); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/41_fused_multi_head_attention/gemm/custom_mma_pipelined.h b/examples/41_fused_multi_head_attention/gemm/custom_mma_pipelined.h new file mode 100644 index 0000000000..fd527a17b6 --- /dev/null +++ b/examples/41_fused_multi_head_attention/gemm/custom_mma_pipelined.h @@ -0,0 +1,401 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "custom_mma_base.h" +#include "cutlass/gemm/gemm.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to A operand + typename TransformA_ = NumericArrayConverter< + typename SmemIteratorA_::Element, + typename IteratorA_::Element, + IteratorA_::Fragment::kElements>, + /// + /// Transformation applied to B operand + typename TransformB_ = NumericArrayConverter< + typename SmemIteratorB_::Element, + typename IteratorB_::Element, + IteratorB_::Fragment::kElements>, + /// Used for partial specialization + typename Enable = bool> +class CustomMmaPipelined : public CustomMmaBase { + public: + ///< Base class + using Base = CustomMmaBase; + + using Shape = + Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = + IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = + IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = 
Policy_; ///< Policy describing tuning details + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + using TransformA = TransformA_; + using TransformB = TransformB_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert( + (Base::kStages == 2), + "MmaPipelined requires kStages set to value 2"); + + static bool const kSmemContainsEntireMat = false; + + private: + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + + protected: + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + CustomMmaPipelined( + typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ) + : Base(shared_storageA, shared_storageB, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storageA.ref(), thread_idx), + smem_iterator_B_(shared_storageB.ref(), thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + CUTLASS_DEVICE + CustomMmaPipelined( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage& st, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : CustomMmaPipelined( + st.operand_A, + st.operand_B, + thread_idx, + warp_idx, + lane_idx) {} + + CUTLASS_DEVICE + bool set_prologue_done(bool value) { + // NOT IMPLEMENTED FOR PIPELINED + } + + CUTLASS_DEVICE + bool set_zero_outside_bounds(bool value) { + // NOT NEEDED FOR PIPELINED + // shared memory will always be zero-filled + } + + template + CUTLASS_DEVICE 
static void prologue( + typename Base::SharedStorage& shared_storage, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) { + prologue( + shared_storage.operand_A, + shared_storage.operand_B, + iterator_A, + iterator_B, + thread_idx, + problem_size_k); + } + + template + CUTLASS_DEVICE static void prologue( + typename Base::SharedStorageA& shared_storageA, + typename Base::SharedStorageB& shared_storageB, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB iterator_B, + int thread_idx, + int problem_size_k) { + // NOT IMPLEMENTED FOR PIPELINED + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + FragmentC const& src_accum, ///< source accumulator tile + TransformA transform_A = + TransformA(), ///< transformation applied to A fragment + TransformB transform_B = + TransformB()) { ///< transformation applied to B fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + + tb_frag_A.clear(); + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_A.clear_mask(gemm_k_iterations <= 1); + iterator_B.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* + // issuing shared memory loads (which have the tightest latency + // requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
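+        //
+        // On the last warp-level k-group, the threadblock fragments fetched
+        // during warp_mma_k == 0 are written into the other half of the
+        // shared-memory double buffer before the write-stage index is flipped.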
+ + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + // Write fragments to shared memory + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + __syncthreads(); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, + -Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k == 0) { + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_A.clear_mask(gemm_k_iterations <= 2); + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + + warp_mma( + accum, + warp_frag_A[warp_mma_k % 2], + warp_frag_B[warp_mma_k % 2], + accum); + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/41_fused_multi_head_attention/gemm/find_default_mma.h b/examples/41_fused_multi_head_attention/gemm/find_default_mma.h new file mode 100644 index 0000000000..ee7d3d6027 --- /dev/null +++ b/examples/41_fused_multi_head_attention/gemm/find_default_mma.h @@ -0,0 +1,191 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Cutlass provides helper template functions to figure out the right + datastructures to instanciate to run a GEMM with various parameters (see + `cutlass/gemm/threadblock/default_mma.h`). However, due to template + instantiation priority rules, it will only create an MmaMultiStage with + kStages=3 (otherwise creates an MmePipelined - which is not compatible with + FastF32). kStages=3 uses too much shared memory and we want to use kStages=2, + so we just copy-pasted some code from `default_mma.h` and + `default_mma_core.h` files and wrapped this template to allow our usecase. + + This is really only for the FastF32 case - aka using TensorCores with fp32. +*/ + +#pragma once + +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +template < + /// Element type for A matrix operand + typename ElementA, + /// Layout type for A matrix operand + typename LayoutA, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + /// Element type for B matrix operand + typename ElementB, + /// Layout type for B matrix operand + typename LayoutB, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + /// Element type for internal accumulation + typename ElementAccumulator, + /// Layout type for C and D matrix operand + typename LayoutC, + /// Operator class tag + typename OperatorClass, + /// Tag indicating architecture to tune for + typename ArchTag, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + /// Number of stages used in the pipelined mainloop + int Stages, + /// Operation perfomed by GEMM + typename Operator, + typename Enable_ = void> +struct FindDefaultMma { + static constexpr bool AccumulatorsInRowMajor = false; + static constexpr SharedMemoryClearOption SharedMemoryClear = + SharedMemoryClearOption::kNone; + using DefaultMma = cutlass::gemm::threadblock::DefaultMma< + ElementA, + LayoutA, + kAlignmentA, + ElementB, + LayoutB, + kAlignmentB, + ElementAccumulator, + LayoutC, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + Stages, + Operator, + AccumulatorsInRowMajor, + SharedMemoryClear>; +}; + +/// Specialization for sm80 / FastF32 / multistage with kStages=2 +template < + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Access granularity of A matrix in units of elements + int kAlignmentA, + typename ElementB_, + /// Layout type for B 
matrix operand + typename LayoutB_, + /// Access granularity of B matrix in units of elements + int kAlignmentB, + typename ElementAccumulator, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape, + int kStages, + typename Operator> +struct FindDefaultMma< + ElementA_, + LayoutA_, + kAlignmentA, + ElementB_, + LayoutB_, + kAlignmentB, + ElementAccumulator, + layout::RowMajor, + arch::OpClassTensorOp, + arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + kStages, + Operator, + typename cutlass::platform::enable_if<(kAlignmentA > 1)>::type> { + using LayoutC = layout::RowMajor; + using OperatorClass = arch::OpClassTensorOp; + using ArchTag = arch::Sm80; + + using DefaultMma_ = cutlass::gemm::threadblock::DefaultMma< + ElementA_, + LayoutA_, + kAlignmentA, + ElementB_, + LayoutB_, + kAlignmentB, + ElementAccumulator, + LayoutC, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + 3, + Operator>; + struct DefaultMma : DefaultMma_ { + using MmaCore_ = typename DefaultMma_::MmaCore; + // Define the threadblock-scoped multistage matrix multiply + using ThreadblockMma = cutlass::gemm::threadblock::MmaMultistage< + typename MmaCore_::Shape, + typename DefaultMma_::IteratorA, + typename MmaCore_::SmemIteratorA, + MmaCore_::kCacheOpA, + typename DefaultMma_::IteratorB, + typename MmaCore_::SmemIteratorB, + MmaCore_::kCacheOpB, + ElementAccumulator, + LayoutC, + typename MmaCore_::MmaPolicy, + kStages>; + }; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/examples/41_fused_multi_head_attention/gemm/mma_accum_lambda_iterator.h b/examples/41_fused_multi_head_attention/gemm/mma_accum_lambda_iterator.h new file mode 100644 index 0000000000..0a67c4e853 --- /dev/null +++ b/examples/41_fused_multi_head_attention/gemm/mma_accum_lambda_iterator.h @@ -0,0 +1,378 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/functional.h" +#include "cutlass/gemm/warp/mma_simt_tile_iterator.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm70.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_iterator_sm80.h" +#include "cutlass/matrix_shape.h" + +/* +TensorCores have different accumulator layouts. +This file provides a class to easily map the accumulator +i-th element with the corresponding matrix row/col. +*/ + +template +struct AccumLambdaIteratorSm80 { + static_assert( + cutlass::platform:: + is_same::value, + "only RowMajor is supported"); + + using Policy = typename T::Policy; + using InstructionShape = typename T::InstructionShape; + using OpDelta = typename T::OpDelta; + using Shape = typename T::Shape; + static int const kElementsPerAccess = InstructionShape::kN / 4; + static int const kRowsPerTile = 8; + static int const kAccumulatorRows = InstructionShape::kM / kRowsPerTile; + + static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset( + int8_t lane_id, + int8_t warp_id, + typename T::TensorCoord const& tile_offset) { + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + return cutlass::MatrixCoord( + quad + tile_offset.row() * Shape::kRow, + lane_in_quad * kElementsPerAccess + + tile_offset.column() * Shape::kColumn); + } + + template + CUTLASS_DEVICE static void iterateRows( + cutlass::MatrixCoord& lane_offset, + FA beginRow, + FB op, + FC endRow) { + // See cutlass/gemm/warp/mma_tensor_op_tile_iterator.h + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < kAccumulatorRows; ++row) { + int accum_m = mma_m * InstructionShape::kM * OpDelta::kRow + + row * kRowsPerTile + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + int mma_accum_start = kAccumulatorRows * kElementsPerAccess * + (mma_n * Policy::MmaIterations::kRow + mma_m); + CUTLASS_PRAGMA_UNROLL + for (int col = 0; col < kElementsPerAccess; ++col) { + int accum_n = mma_n * InstructionShape::kN * OpDelta::kColumn + + col + lane_offset.column(); + int idx = mma_accum_start + row * kElementsPerAccess + col; + op(accum_m, accum_n, idx); + } + } + + endRow(accum_m); + } + } + } + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { + // In each warp, 4 threads will work on the same row + // - the ones with the same `quad` + auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1); + myValue = fn(myValue, otherV); + otherV = __shfl_xor_sync(0xffffffff, myValue, 2); + myValue = fn(myValue, otherV); + int lane_in_quad = (lane_id & 3); + return lane_in_quad == 0; + } +}; + +template +struct AccumLambdaIteratorSm70 { + static_assert( + cutlass::platform:: + is_same::value, + "only RowMajor is supported"); + + using Policy = typename T::Policy; + 
using InstructionShape = typename T::InstructionShape; + using OpDelta = typename T::OpDelta; + using Shape = typename T::Shape; + using Element = accum_t; + + static int const kElementsPerPartial = 4; + using EleShapePerPatial = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::MatrixShape<2, 2>, + cutlass::MatrixShape<1, 4>>::type; + static int const kElementsPerMma = 8; + static int const kAccumulatorPatials = 2; + using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>; + + static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset( + int8_t lane_id, + int8_t warp_id, + typename T::TensorCoord const& tile_offset) { + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + int accum_m, accum_n; + + if (cutlass::platform::is_same::value) { + // (quad[2],quad[0])+lane_in_quad[0] + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); + // (quad[1])+lane_in_quad[1] + accum_n = + ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + + (lane_in_quad & 2); + } else { + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + + lane_in_quad; // (quad[2],quad[0]) + accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; + } + return cutlass::MatrixCoord( + accum_m + tile_offset.row() * Shape::kRow, + accum_n + tile_offset.column() * Shape::kColumn); + } + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { + static_assert( + cutlass::platform::is_same::value, + "update to support non-float accum"); + // https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-884-f16 + // T0 & T2 share same line within a quad + auto otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 1); + myValue = fn(myValue, otherV); + // quad 0 and quad 2 are on the same lines + otherV = __shfl_xor_sync(0xffffffff, myValue, 1 << 3); + myValue = fn(myValue, otherV); + return (lane_id & ((1 << 1) | (1 << 3))) == 0; + } + + template + CUTLASS_DEVICE static void iterateRows( + cutlass::MatrixCoord& lane_offset, + FA beginRow, + FB op, + FC endRow) { + CUTLASS_PRAGMA_UNROLL + for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < EleShapePerPatial::kRow; ++m) { + int accum_m = tile_m * Policy::InterleavedTile::kRow + + mma_m * QuadShapePerPatialMma::kRow + m * 2 + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; + ++tile_n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; + ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < kAccumulatorPatials; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { + int mma_accum_start = + (((tile_n * Policy::TileIterations::kRow + tile_m) * + Policy::MmaIterations::kColumn + + mma_n) * + Policy::MmaIterations::kRow + + mma_m) * + kElementsPerMma; + int accum_n = tile_n * Policy::InterleavedTile::kColumn + + mma_n * QuadShapePerPatialMma::kColumn + + p * Policy::InterleavedTile::kColumn / 2 + n + + lane_offset.column(); + int idx = mma_accum_start + p * kElementsPerPartial + + m * EleShapePerPatial::kColumn + n; + op(accum_m, accum_n, idx); + } + } + } + } + endRow(accum_m); + } + } + } + } +}; + +template +struct AccumLambdaIteratorSimt { + using Policy = typename T::Policy; + using Iterations = typename 
T::Iterations; + using Element = typename T::Element; + using Delta = typename T::Delta; + using Shape = typename T::Shape; + static_assert( + cutlass::platform:: + is_same::value, + "only RowMajor is supported"); + + template + CUTLASS_DEVICE static bool reduceSameRow(int lane_id, DT& myValue, F fn) { + CUTLASS_PRAGMA_UNROLL + for (int bit = 1; bit < Policy::WarpShape::kColumn; bit *= 2) { + auto otherV = __shfl_xor_sync(0xffffffff, myValue, bit); + myValue = fn(myValue, otherV); + } + return (lane_id & (Policy::WarpShape::kColumn - 1)) == 0; + } + + template + CUTLASS_DEVICE static void iterateRows( + cutlass::MatrixCoord& lane_offset, + FA beginRow, + FB op, + FC endRow) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { + int accum_m = mma_m * Delta::kRow + m + lane_offset.row(); + beginRow(accum_m); + + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { + int accum_n = + mma_n * Policy::WarpShape::kColumn * Policy::LaneMmaShape::kN + + lane_offset.column(); + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { + int idx = n + + Policy::LaneMmaShape::kN * + (mma_n + + Iterations::kColumn * + (m + mma_m * Policy::LaneMmaShape::kM)); + op(accum_m, accum_n + n, idx); + } + } + endRow(accum_m); + } + } + } + + static cutlass::MatrixCoord CUTLASS_DEVICE get_lane_offset( + int8_t lane_id, + int8_t warp_id, + typename T::TensorCoord const& tile_offset) { + static_assert( + cutlass::platform::is_same< + typename Policy::LaneLayout, + cutlass::layout::RowMajorInterleaved<1>>::value, + ""); + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + cutlass::MatrixCoord lane_offset = lane_layout.inverse(lane_id) * + cutlass::MatrixCoord(Policy::LaneMmaShape::kM, + Policy::LaneMmaShape::kN); + return lane_offset + + tile_offset * cutlass::MatrixCoord(Shape::kRow, Shape::kColumn); + } +}; + +template +struct DefaultMmaAccumLambdaIterator; + +// Simt +template +struct DefaultMmaAccumLambdaIterator< + cutlass::gemm::warp::MmaSimtTileIterator< + S, + cutlass::gemm::Operand::kC, + accum_t, + cutlass::layout::RowMajor, + P, + 1, + 1>, + accum_t, + kWarpSize> { + using WarpIterator = typename cutlass::gemm::warp::MmaSimtTileIterator< + S, + cutlass::gemm::Operand::kC, + accum_t, + cutlass::layout::RowMajor, + P, + 1, + 1>; + using Iterator = AccumLambdaIteratorSimt; +}; + +// TensorOp - Volta +template +struct DefaultMmaAccumLambdaIterator< + cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + cutlass::MatrixShape<1, 1>>, + accum_t, + kWarpSize> { + using WarpIterator = + typename cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + cutlass::MatrixShape<1, 1>>; + using Iterator = AccumLambdaIteratorSm70; +}; + +// TensorOp - Sm75+ +template < + typename S1, + typename S2, + typename S3, + typename accum_t, + int kWarpSize> +struct DefaultMmaAccumLambdaIterator< + cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + S3>, + accum_t, + kWarpSize> { + using WarpIterator = + typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + S1, + accum_t, + cutlass::layout::RowMajor, + S2, + S3>; + using Iterator = AccumLambdaIteratorSm80; +}; diff --git a/examples/41_fused_multi_head_attention/gemm/mma_from_smem.h 
b/examples/41_fused_multi_head_attention/gemm/mma_from_smem.h new file mode 100644 index 0000000000..3e41274349 --- /dev/null +++ b/examples/41_fused_multi_head_attention/gemm/mma_from_smem.h @@ -0,0 +1,1955 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Tools and utils to store a GEMM output in shmem, and to use that + output as operandA for another GEMM back-to-back +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/functional.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/platform/platform.h" +#include "cutlass/transform/threadblock/vector_iterator.h" + +#include "../epilogue/epilogue_thread_apply_logsumexp.h" +#include "../gemm/mma_accum_lambda_iterator.h" +#include "../gemm_kernel_utils.h" +#include "../iterators/default_warp_iterator_from_smem.h" +#include "../iterators/make_residual_last.h" +#include "../iterators/transpose_warp_iterator.h" +#include "../iterators/warp_iterator_from_smem.h" +#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/gemm/threadblock/mma_multistage.h" +#include "cutlass/gemm/threadblock/mma_pipelined.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +/// Shared storage object needed by accumulator +/// From 13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h +template < + typename Shape_, + typename Element_, + typename Layout_, + typename Padding_> +class AccumulatorSharedStorage { + public: + // + // Type definitions + // + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using Padding = Padding_; + + /// Tensor reference to the accumulator + using TensorRefAccum = cutlass::TensorRef; + + /// Shape of the accumulator matrix in shared memory + using ShapeAccum = cutlass:: + MatrixShape; + + public: + // + // Data members + // + + /// Buffer for accumulator + cutlass::AlignedBuffer accum; + + public: + // + // Methods + // + + /// Returns a layout object for the Accum matrix + CUTLASS_DEVICE + static Layout LayoutAccum() { + return Layout::packed({ShapeAccum::kRow, ShapeAccum::kColumn}); + } + + /// Returns a TensorRef to the Accumulator + CUTLASS_HOST_DEVICE + TensorRefAccum accum_ref() { + return TensorRefAccum{accum.data(), LayoutAccum()}; + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Taken from +// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_base_smem_accumulator.h +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
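+///
+/// Unlike the regular threadblock MmaBase, operand A is not staged here: it is
+/// read directly from an accumulator tile that already resides in shared
+/// memory, so the SharedStorage below only holds the staging buffer for
+/// operand B.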
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + // Maximum K dimension - also the dimension of the shared-memory + // holding `OperandA` + int kMaxK_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Layout in shared-memory of operand A + typename SmemLayoutA, + /// Used for partial specialization + typename Enable = bool> +class MmaBaseFromSharedMemory { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + static constexpr int kMaxK = kMaxK_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape< + Shape::kM / WarpGemm::kM, + Shape::kN / WarpGemm::kN, + Shape::kK / WarpGemm::kK>; + using WarpCount1 = WarpCount; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator::Policy::MmaShape::kK); + static int const kWarpGemmIterations1 = kWarpGemmIterations; + + /// Number of stages + static int const kStages = Stages; + + /// If this is true, we fill the entire shmem buffer at start + /// and don't need to iterate through it in a circular fashion + static bool const kSmemContainsEntireB = kMaxK <= Shape::kK * kStages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = + TensorRef; + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape< + Shape::kK * kStages + Policy::SmemPaddingB::kRow, + Shape::kN + Policy::SmemPaddingB::kColumn>; + + public: + // + // Data members + // + + /// Buffer for B operand + AlignedBuffer operand_B; + + public: + // + // Methods + // + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { + return TensorRefB{operand_B.data(), LayoutB()}; + } + }; + + protected: + // + // Data members + // + + // /// Iterator to load a warp-scoped tile of A operand from shared memory + // typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + MmaBaseFromSharedMemory( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + TensorRefB& b_tile, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_B_(b_tile, lane_idx) {} +}; + +namespace { + +// has necessary trait compliance with WarpIteratorFromSmem but doesn't do +// anything, can be default initialized, and uses fragment that takes up +// (almost) no space. 
this warp iterator is selected at compile time when +// elementwise on-the-fly scaling for operand A is disabled, in which case +// operations related to loading scale factors for operand A get wiped out by +// the compiler. +template +class NoOpWarpIteratorScale { + public: + // in pipelined+multistage MMA implementations we keep an array of fragments. + // if we aren't using scaling we don't want to waste registers on fragments + // of scale elements, so ideally this would be sized 0. + // Since arrays of zero-sized objects are not allowed, using size as 1. + // The compiler will most likely wipe it out anyways. + using Fragment = cutlass::Array; + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale() {} + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale(TensorRef const&, int) {} + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale& add_tile_offset( + typename TensorRef::TensorCoord const&) { + return *this; + } + + CUTLASS_HOST_DEVICE + NoOpWarpIteratorScale& operator++() { + return *this; + } + + CUTLASS_DEVICE + void load(Fragment&) const {} +}; + +// if scaling is enabled, performs fragment elementwise multiplication between +// fragment and its scaling factor. +template +class FragmentElementwiseScaler; + +// specialization for scaling being enabled. +template +class FragmentElementwiseScaler { + public: + // cast scale_frag to correct type then apply elementwise to fragment + CUTLASS_DEVICE + static Fragment apply(Fragment frag, FragmentScale const& scale_frag) { + Fragment converted_scale_frag = cutlass::NumericArrayConverter< + typename Fragment::Element, + typename FragmentScale::Element, + FragmentScale::kElements>()(scale_frag); + return cutlass::multiplies()(frag, converted_scale_frag); + } +}; + +// specialization for scaling being disabled. doesn't do anything and should +// just get wiped out by the compiler. +template +class FragmentElementwiseScaler { + public: + CUTLASS_DEVICE + static Fragment apply(Fragment frag, FragmentScale const&) { + return frag; + } +}; +} // namespace + +//////////////////////////////////////////////////////////////////////////////// +// Taken from +// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_pipelined_smem_accumulator.h +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
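+///
+/// Two-stage (double-buffered) variant: operand A fragments are loaded by
+/// WarpIteratorA_ from an accumulator tile resident in shared memory, and are
+/// optionally multiplied elementwise by an A_scale tile when ScaleOperandA_ is
+/// true, while operand B tiles are streamed from global memory through the
+/// staging buffer provided by MmaBaseFromSharedMemory.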
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + // BEGIN smem + /// Iterates over the intermediate accumulator tile in shared memory + typename WarpIteratorA_, + /// whether or not to perform elementwise multiplication of A + // by another matrix (A_scale) that is also kept in shared memory prior + // to matmul A @ B + bool ScaleOperandA_, + /// Max GEMM problem size in K dimension + int MaxK, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to B operand + typename TransformB_ = NumericArrayConverter< + typename SmemIteratorB_::Element, + typename IteratorB_::Element, + IteratorB_::Fragment::kElements>, + /// Used for partial specialization + typename Enable = bool> +class MmaPipelinedFromSharedMemory : public MmaBaseFromSharedMemory< + Shape_, + MaxK, + Policy_, + 2, + typename WarpIteratorA_::Layout> { + public: + ///< Base class + using Base = MmaBaseFromSharedMemory< + Shape_, + MaxK, + Policy_, + 2, + typename WarpIteratorA_::Layout>; + + using Shape = + Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + static constexpr bool ScaleOperandA = ScaleOperandA_; + + using WarpIteratorA = WarpIteratorA_; + ///< loads fragments of A_scale from shared memory if operand A scaling is + ///< enabled. otherwise no-op. + using WarpIteratorAScale = typename cutlass::platform::conditional< + ScaleOperandA, + WarpIteratorA, + NoOpWarpIteratorScale>::type; + + using IteratorB = + IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using SmemIteratorB = SmemIteratorB_; + + using TransformB = TransformB_; + + // + // Dependent types + // + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert( + (Base::kStages == 2), + "MmaPipelined requires kStages set to value 2"); + + private: + using WarpFragmentA = typename Operator::FragmentA; + + /// fragment type of OperandA elementwise scaling matrix. (almost) empty + /// if operand A scaling is disabled. + using WarpFragmentAScale = typename WarpIteratorAScale::Fragment; + + using WarpFragmentB = typename Operator::FragmentB; + + /// applies scaling factor to operand A fragment if operand A scaling is + /// enabled. otherwise no-op. 
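+  /// (the ScaleOperandA == false specialization of FragmentElementwiseScaler
+  /// returns the fragment unchanged, so it compiles away.)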
+ using FragmentAScaler = FragmentElementwiseScaler< + WarpFragmentA, + WarpFragmentAScale, + ScaleOperandA>; + + protected: + // /// Iterator to write threadblock-scoped tile of A operand to shared memory + // SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + + /// Iterator to load a warp-scoped tile of A operand from intermediate + /// accumulator tile + WarpIteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of A_scale from intermediate + /// accumulator tile (only used if ScaleOperandA_ is true) + WarpIteratorAScale warp_tile_iterator_A_scale_; + + public: + /// constructor for MMA with operand A scaling enabled. + CUTLASS_DEVICE + MmaPipelinedFromSharedMemory( + typename Base::TensorRefA a, // Operand A in shared memory + typename Base::TensorRefA a_scale, // Operand A_scale in shared memory + typename Base::TensorRefB + b_staging, // staging memory for loading tiles of B + int thread_idx, + int warp_idx, + int lane_idx) + : Base(b_staging, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A_(a, lane_idx), + warp_tile_iterator_A_scale_(a_scale, lane_idx), + smem_iterator_B_(b_staging, thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_A_scale_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + /// Construct from tensor references + CUTLASS_DEVICE + MmaPipelinedFromSharedMemory( + typename Base::TensorRefA a, ///< Operand A in shared memory + typename Base::TensorRefB b_staging, ///< staging memory for loading B + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx) ///< ID of each thread within a warp + : Base(b_staging, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A_(a, lane_idx), + smem_iterator_B_(b_staging, thread_idx) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); 
+ } + + // For API compatibility with MmaMultistageFromSharedMemory + // but not supported as it worsens perf: older gpus < sm80 don't + // support async tranfers and have to waste registers + CUTLASS_DEVICE + void set_prologue_done(bool value) {} + CUTLASS_DEVICE + static void prologue( + typename Base::SharedStorage& shared_storage, + IteratorB iterator_B1, + int thread_idx, + int problem_size_0_n) {} + + CUTLASS_DEVICE + static void drain_cp_asyncs() {} + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC& accum, ///< destination accumulator tile + // IteratorA iterator_A, ///< iterator over A + // operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + FragmentC const& src_accum, ///< source accumulator tile + // TransformA transform_A = TransformA(), ///< transformation + // applied to A fragment + TransformB transform_B = + TransformB()) { ///< transformation applied to B fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentB tb_frag_B; + + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + iterator_B.set_residual_tile(gemm_k_iterations == 1); + iterator_B.load(tb_frag_B); + + ++iterator_B; + + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + ++this->smem_iterator_B_; + + __syncthreads(); + + // remember that WarpFragmentAScale and WarpIteratorAScale are empty/no-op + // if scaling is disabled. + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentAScale warp_frag_A_scale[2]; + WarpFragmentB warp_frag_B[2]; + warp_frag_A[0].clear(); + warp_frag_A_scale[0].clear(); + warp_frag_B[0].clear(); + + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_A_scale_.load(warp_frag_A_scale[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_A_scale_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + + // Avoid reading out of bounds + iterator_B.set_residual_tile(gemm_k_iterations == 2); + iterator_B.clear_mask(gemm_k_iterations <= 1); + + // Issue loads during the first warp-level matrix multiply-add *AFTER* + // issuing shared memory loads (which have the tightest latency + // requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
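+        //
+        // hasNext is cleared on the final mainloop iteration so that the
+        // trailing shared-memory and warp-fragment loads are skipped once no
+        // further k-tile remains.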
+ bool hasNext = true; + + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + if (gemm_k_iterations > 1) { + // Write fragments to shared memory + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + } + + __syncthreads(); + + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory SMEM: Don't reset iterator A, as + // we are continuing our iteration at this point + if (smem_write_stage_idx == 1) { + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } else { + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + } + + smem_write_stage_idx ^= 1; + hasNext = gemm_k_iterations > 1; + } + + // Only read the next if we need to + if (hasNext) { + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_A_scale_.load( + warp_frag_A_scale[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_A_scale_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k == 0) { + iterator_B.load(tb_frag_B); + + ++iterator_B; + + // Avoid reading out of bounds if this was the last loop iteration + iterator_B.set_residual_tile(gemm_k_iterations == 3); + iterator_B.clear_mask(gemm_k_iterations <= 2); + } + } + + warp_mma( + accum, + FragmentAScaler::apply( + warp_frag_A[warp_mma_k % 2], warp_frag_A_scale[warp_mma_k % 2]), + warp_frag_B[warp_mma_k % 2], + accum); + } + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Taken from +// https://github.com/NVIDIA/cutlass/blob/master/examples/13_two_tensor_op_fusion/threadblock/b2b_mma_multistage_smem_accumulator.h +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. 
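+///
+/// Multistage variant for Sm80+ built on cp.async: tiles of operand B are
+/// prefetched asynchronously into a Stages_-deep shared-memory buffer, while
+/// operand A (and its optional elementwise scale) is read from an accumulator
+/// tile already resident in shared memory via WarpIteratorA1_. When the whole
+/// K extent fits in the buffer (kSmemContainsEntireB), the circular-buffer
+/// bookkeeping and the mainloop cp.async copies are skipped.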
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape1_, + /// Iterates over the intermediate accumulator tile in shared memory + typename WarpIteratorA1_, + /// whether or not to perform elementwise multiplication of A + // by another matrix (A_scale) that is also kept in shared memory prior + // to matmul A @ B + bool ScaleOperandA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB1, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy1_, + /// Number of stages, + int Stages_, + int kMaxK_, + /// Used for partial specialization + typename Enable = bool> +class MmaMultistageFromSharedMemory : public MmaBaseFromSharedMemory< + Shape1_, + kMaxK_, + Policy1_, + Stages_, + typename WarpIteratorA1_::Layout> { + public: + ///< Base class + using Base = MmaBaseFromSharedMemory< + Shape1_, + kMaxK_, + Policy1_, + Stages_, + typename WarpIteratorA1_::Layout>; + + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape1 = Shape1_; + ///< Iterates over tiles of B operand in global memory + using IteratorB1 = IteratorB1_; + using IteratorB = IteratorB1; + ///< Policy describing tuning details + using Policy1 = Policy1_; + + using SmemIteratorB1 = SmemIteratorB1_; + using WarpIteratorA1 = WarpIteratorA1_; ///< Iterates over the intermediate + ///< accumulator tile in shared memory + static constexpr bool ScaleOperandA = ScaleOperandA_; + + ///< warp level iterator over A_scale matrix tile kept in shared memory. + ///< if elementwise A scaling is disabled then everything this does is no-op. + using WarpIteratorAScale = typename cutlass::platform::conditional< + ScaleOperandA, + WarpIteratorA1, + NoOpWarpIteratorScale>::type; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpB1 = CacheOpB1; + static constexpr bool kSmemContainsEntireB = Base::kSmemContainsEntireB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC1 = typename Policy1::Operator::FragmentC; + using FragmentC = FragmentC1; + + /// Warp-level Mma + using Operator1 = typename Policy1::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on B operand + static ComplexTransform const kTransformB1 = Operator1::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + static_assert( + Base::kWarpGemmIterations1 > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + /// Number of cp.async instructions to load one stage of operand B + static int const TBLoadIterationsB1 = + IteratorB1::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB1 = + (TBLoadIterationsB1 + Base::kWarpGemmIterations1 - 1) / + Base::kWarpGemmIterations1; + }; + + static constexpr int kNumStagesConcurrentLoad = + kSmemContainsEntireB ? 
Base::kStages : Base::kStages - 1; + + private: + using WarpLoadedFragmentA1 = typename Operator1::FragmentA; + /// fragment of OperandA scale matrix. if operand A scaling is disabled this + /// is (almost) empty. + using WarpLoadedFragmentA1Scale = typename WarpIteratorAScale::Fragment; + using WarpLoadedFragmentB1 = typename Operator1::FragmentB; + using WarpTransformedFragmentA1 = typename Operator1::TransformedFragmentA; + using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; + + /// applies elementwise scaling to fragment of A. if operand A scaling is + /// disabled this is a no-op. + using FragmentAScaler = FragmentElementwiseScaler< + WarpLoadedFragmentA1, + WarpLoadedFragmentA1Scale, + ScaleOperandA>; + + private: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A1 operand from intermediate + /// accumulator tile + WarpIteratorA1 warp_tile_iterator_A1_; + + /// Iterator to load a warp-scoped tile of A1_scale operand from shared memory + /// if operand A scaling is disabled everything this does is a no-op. + WarpIteratorAScale warp_tile_iterator_A1_scale_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB1 smem_iterator_B1_; + + bool prologue_done_; + + public: + /// constructor for MMA with operand A scaling enabled. + CUTLASS_DEVICE + MmaMultistageFromSharedMemory( + typename Base::TensorRefA a, + typename Base::TensorRefA a_scale, + typename Base::TensorRefB b_tile, + int thread_idx, + int warp_idx, + int lane_idx) + : Base(b_tile, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A1_(a, lane_idx), + warp_tile_iterator_A1_scale_(a_scale, lane_idx), + smem_iterator_B1_(b_tile, thread_idx), + prologue_done_(false) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + int warp_idx_mn_1 = + warp_idx % (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; + int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; + + // Add per-warp offsets in units of warp-level tiles + warp_tile_iterator_A1_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + warp_tile_iterator_A1_scale_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); + } + + /// Construct from tensor references + CUTLASS_DEVICE + MmaMultistageFromSharedMemory( + typename Base::TensorRefA a, + typename Base::TensorRefB b_tile, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : Base(b_tile, thread_idx, warp_idx, lane_idx), + warp_tile_iterator_A1_(a, lane_idx), + smem_iterator_B1_(b_tile, thread_idx), + prologue_done_(false) { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn_1 = + warp_idx % 
(Base::WarpCount1::kM * Base::WarpCount1::kN); + int warp_idx_k_1 = warp_idx / (Base::WarpCount1::kM * Base::WarpCount1::kN); + + int warp_idx_m_1 = warp_idx_mn_1 % Base::WarpCount1::kM; + int warp_idx_n_1 = warp_idx_mn_1 / Base::WarpCount1::kM; + + // Add per-warp offsets in units of warp-level tiles + warp_tile_iterator_A1_.add_tile_offset( + {warp_idx_m_1, Base::kWarpGemmIterations1 * warp_idx_k_1}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations1 * warp_idx_k_1, warp_idx_n_1}); + } + + CUTLASS_DEVICE + void set_prologue_done(bool value) { + prologue_done_ = value; + } + + CUTLASS_DEVICE + static void prologue( + typename Base::SharedStorage& shared_storage, + IteratorB iterator_B1, + int thread_idx, + int problem_size_0_n) { + SmemIteratorB1 smem_iterator_B1(shared_storage.operand_B_ref(), thread_idx); + _prologue( + iterator_B1, + (problem_size_0_n + Base::Shape::kK - 1) / Base::Shape::kK, + smem_iterator_B1); + } + + CUTLASS_DEVICE + static void drain_cp_asyncs() { + // commit and drain all pending and predicated cp.async pnz from the GEMM + // mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance_1( + IteratorB1& iterator_B1, + int group_start_B1 = 0) { + iterator_B1.set_iteration_index( + group_start_B1 * IteratorB1::kAccessesPerVector); + this->smem_iterator_B1_.set_iteration_index(group_start_B1); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB1; ++j) { + if (group_start_B1 + j < Detail::TBLoadIterationsB1) { + typename IteratorB1::AccessType* dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B1.get(); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B1.valid()); + + ++iterator_B1; + } + ++this->smem_iterator_B1_; + } + } + } + + CUTLASS_DEVICE + static void _prologue( + IteratorB& iterator_B1, + int32_t gemm_k_iterations_1, + SmemIteratorB1& smem_iterator_B1_) { + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < kNumStagesConcurrentLoad; + ++stage, --gemm_k_iterations_1) { + iterator_B1.set_residual_tile(gemm_k_iterations_1 == 1); + iterator_B1.clear_mask(gemm_k_iterations_1 == 0); + + iterator_B1.set_iteration_index(0); + smem_iterator_B1_.set_iteration_index(0); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) { + typename IteratorB1::AccessType* dst_ptr = + reinterpret_cast( + smem_iterator_B1_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B1.get(), iterator_B1.valid()); + + ++iterator_B1; + } + + ++smem_iterator_B1_; + } + + // Move to the next stage + iterator_B1.add_tile_offset({1, 0}); + + smem_iterator_B1_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. 
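+      // cp_async_fence() commits the cp.async instructions issued above as a
+      // single group, which cp_async_wait later waits on stage by stage.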
+ cutlass::arch::cp_async_fence(); + } + iterator_B1.set_residual_tile(gemm_k_iterations_1 == 1); + iterator_B1.clear_mask(gemm_k_iterations_1 == 0); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations_1_, + ///< destination accumulator tile + FragmentC1& accum, + ///< iterator over B1 operand in global memory + IteratorB1 iterator_B1, + ///< initial value of accumulator + FragmentC1 const& src_accum) { + // 2nd Gemm + + // + // Prologue + // + // Perform accumulation in the 'd' output operand + accum = src_accum; + + if (!prologue_done_) { + _prologue(iterator_B1, gemm_k_iterations_1_, smem_iterator_B1_); + } else if (!kSmemContainsEntireB) { + // Restore the iterators increments + + int gemm_k_iterations_1 = gemm_k_iterations_1_; + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < kNumStagesConcurrentLoad; + ++stage, --gemm_k_iterations_1) { + iterator_B1.set_iteration_index(0); + this->smem_iterator_B1_.set_iteration_index(0); + + // Load for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::TBLoadIterationsB1; ++j) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + ++iterator_B1; + } + ++this->smem_iterator_B1_; + } + iterator_B1.add_tile_offset({1, 0}); + this->smem_iterator_B1_.add_tile_offset({1, 0}); + } + iterator_B1.set_residual_tile(gemm_k_iterations_1 <= 1); + iterator_B1.clear_mask(gemm_k_iterations_1 <= 0); + } + + // DEPBAR+SYNC + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // remember that WarpFragmentAScale and WarpIteratorAScale are no-op/empty + // if scaling is disabled. + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA1 warp_loaded_frag_A1[2]; + WarpLoadedFragmentA1Scale warp_loaded_frag_A1_scale[2]; + WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; + WarpTransformedFragmentA1 warp_transformed_frag_A1[2]; + WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; + + Operator1 warp_mma1; + + warp_tile_iterator_A1_.load(warp_loaded_frag_A1[0]); + ++warp_tile_iterator_A1_; + + warp_tile_iterator_A1_scale_.load(warp_loaded_frag_A1_scale[0]); + ++warp_tile_iterator_A1_scale_; + + this->warp_tile_iterator_B_.set_kgroup_index(0); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B1[0]); + ++this->warp_tile_iterator_B_; + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma1.transform( + warp_transformed_frag_A1[0], + warp_transformed_frag_B1[0], + FragmentAScaler::apply( + warp_loaded_frag_A1[0], warp_loaded_frag_A1_scale[0]), + warp_loaded_frag_B1[0]); + + // tf32x3 kernels use staging accumulation. warp_mma uses a temporary + // accumulator and this temporary accumulator is added to the final + // accumulator once in every mainloop iteration. 
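+    // The fast-accurate F32 paths (OpMultiplyAddFastF32 /
+    // OpMultiplyAddComplexFastF32) accumulate into tmp_accum and fold it into
+    // 'accum' via plus_accum once per mainloop iteration (and once more after
+    // the loop); all other math operators accumulate directly into 'accum'.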
+ plus plus_accum; + + FragmentC1 tmp_accum; + + if (platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + tmp_accum.clear(); + } + + // + // Mainloop + // + + CUTLASS_PRAGMA_UNROLL + for (int gemm_k_iterations_1 = gemm_k_iterations_1_ - (Base::kStages - 1); + gemm_k_iterations_1 > (-Base::kStages + 1); + gemm_k_iterations_1--) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; + ++warp_mma_k) { + // Load warp-level tile from accumulator fragment (A) + // or shared memory (operand B) + this->warp_tile_iterator_B_.set_kgroup_index( + (warp_mma_k + 1) % Base::kWarpGemmIterations1); + // skip warp tile loading for the last kgroup (we are out of the buf) + if (gemm_k_iterations_1 > (-Base::kStages + 2) || + warp_mma_k < Base::kWarpGemmIterations1 - 1) { + warp_tile_iterator_A1_.load( + warp_loaded_frag_A1[(warp_mma_k + 1) % 2]); + warp_tile_iterator_A1_scale_.load( + warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load( + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + } + ++warp_tile_iterator_A1_; + ++warp_tile_iterator_A1_scale_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k > 0) + warp_mma1.transform( + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + FragmentAScaler::apply( + warp_loaded_frag_A1[warp_mma_k % 2], + warp_loaded_frag_A1_scale[warp_mma_k % 2]), + warp_loaded_frag_B1[warp_mma_k % 2]); + + if (platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + warp_mma1( + tmp_accum, + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + tmp_accum); + + if (warp_mma_k == 0) { + accum = plus_accum(accum, tmp_accum); + tmp_accum.clear(); + } + } else { + warp_mma1( + accum, + warp_transformed_frag_A1[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + accum); + } + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations1 - 1) { + int group_start_iteration_B1; + + group_start_iteration_B1 = warp_mma_k * Detail::kAccessesPerGroupB1; + + if (!kSmemContainsEntireB) { + copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); + } + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations1) { + int group_start_iteration_B1; + group_start_iteration_B1 = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB1; + + if (!kSmemContainsEntireB) { + copy_tiles_and_advance_1(iterator_B1, group_start_iteration_B1); + } + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages have committed. 
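+          // Together with the __syncthreads() that follows, this makes the
+          // stage that just finished copying visible to every warp before its
+          // shared-memory tiles are read.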
+ arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_B1.add_tile_offset({1, 0}); + + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (!kSmemContainsEntireB) { + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * + Base::kWarpGemmIterations1, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + } + + iterator_B1.set_residual_tile(gemm_k_iterations_1 == 2); + iterator_B1.clear_mask(gemm_k_iterations_1 == 1); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations1) + warp_mma1.transform( + warp_transformed_frag_A1[(warp_mma_k + 1) % 2], + warp_transformed_frag_B1[(warp_mma_k + 1) % 2], + FragmentAScaler::apply( + warp_loaded_frag_A1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A1_scale[(warp_mma_k + 1) % 2]), + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + } + } + + if (platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddFastF32>::value || + platform::is_same< + typename Operator1::MathOperator, + arch::OpMultiplyAddComplexFastF32>::value) { + accum = plus_accum(accum, tmp_accum); + } + } +}; + +// Converts a "regular" Mma into their counterpart from shared memory +template < + typename Mma_, + int kMaxK, + typename WarpIteratorA_, + /// whether or not to apply elementwise multiplication of operand A by + /// another matrix in shared memory before usage in A @ B + bool kScaleOperandA, + bool kTransposeA = false> +struct DefaultMmaFromSharedMemory; + +// Mma pipelined +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + typename WarpIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to A operand + typename TransformA_, + /// Transformation applied to B operand + typename TransformB_, + // Max MMA problem size K + int kMaxK, + /// whether or not to apply elementwise multiplication of operand A by + /// another matrix in shared memory before usage in A @ B + bool kScaleOperandA, + bool kTransposeA> +struct DefaultMmaFromSharedMemory< + MmaPipelined< + Shape_, + IteratorA_, + SmemIteratorA_, + IteratorB_, + SmemIteratorB_, + ElementC_, + LayoutC_, + Policy_, + TransformA_, + TransformB_>, + kMaxK, + WarpIteratorA_, + kScaleOperandA, + kTransposeA> { + using RegularMma = 
MmaPipelined< + Shape_, + IteratorA_, + SmemIteratorA_, + IteratorB_, + SmemIteratorB_, + ElementC_, + LayoutC_, + Policy_, + TransformA_, + TransformB_>; + + using WarpShape = typename Policy_::Operator::Shape; + using InstructionShape = typename Policy_::Operator::InstructionShape; + using ArchMmaOperator = typename Policy_::Operator; + + static constexpr bool kIsTransposedA = false; + using WarpIteratorA = WarpIteratorA_; + using IteratorB = + typename cutlass::transform::threadblock::MakeIteratorResidualLast< + IteratorB_>::Iterator; + + using Mma = typename cutlass::gemm::threadblock::MmaPipelinedFromSharedMemory< + Shape_, + WarpIteratorA, + kScaleOperandA, + kMaxK, + IteratorB, + SmemIteratorB_, + ElementC_, + LayoutC_, + Policy_>; +}; + +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + typename WarpIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear, + int kMaxK, + /// whether or not to apply elementwise multiplication of operand A by + /// another matrix in shared memory before usage in A @ B + bool kScaleOperandA, + bool kTransposeA> +struct DefaultMmaFromSharedMemory< + MmaMultistage< + Shape_, + IteratorA_, + SmemIteratorA_, + CacheOpA, + IteratorB_, + SmemIteratorB_, + CacheOpB, + ElementC_, + LayoutC_, + Policy_, + Stages, + SharedMemoryClear>, + kMaxK, + WarpIteratorA_, + kScaleOperandA, + kTransposeA> { + using RegularMma = MmaMultistage< + Shape_, + IteratorA_, + SmemIteratorA_, + CacheOpA, + IteratorB_, + SmemIteratorB_, + CacheOpB, + ElementC_, + LayoutC_, + Policy_, + Stages, + SharedMemoryClear>; + + using WarpShape = typename Policy_::Operator::Shape; + using InstructionShape = typename Policy_::Operator::InstructionShape; + using WarpIteratorTranspose = TransposeWarpIterator; + static constexpr bool kIsTransposedA = + WarpIteratorTranspose::kSupportsTranspose && kTransposeA; + using WarpIteratorA = typename platform::conditional< + kIsTransposedA, + typename WarpIteratorTranspose::Iterator, + WarpIteratorA_>::type; + + // Reduce the number of stages if we don't need that many + static int constexpr kStagesMax = + (kMaxK + int(Shape_::kK) - 1) / int(Shape_::kK); + static int constexpr kStages = cutlass::const_min(Stages, kStagesMax); + + using IteratorB = + typename cutlass::transform::threadblock::MakeIteratorResidualLast< + IteratorB_>::Iterator; + using Mma = + typename cutlass::gemm::threadblock::MmaMultistageFromSharedMemory< + Shape_, + WarpIteratorA, 
+ kScaleOperandA, + IteratorB, + SmemIteratorB_, + RegularMma::kCacheOpB, + ElementC_, + LayoutC_, + Policy_, + kStages, + kMaxK>; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename IteratorC, + typename Operator, + typename scalar_t, + typename WarpShape_, + typename ThreadblockShape_> +struct B2bGemm; + +// Tensor Cores >= Sm75 specialization (Ampere ...) +template < /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Element type + typename Element_, + /// Layout of operand in memory + typename Layout_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + /// Interval between adjacent *MMA instructions (in units of MMA + /// instructions, concept: MatrixShape) + typename OpDelta_, + typename Operator, + typename scalar_t, + typename WarpShape_, + typename ThreadblockShape_> +struct B2bGemm< + cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + Shape_, + Element_, + Layout_, + InstructionShape_, + OpDelta_>, + Operator, + scalar_t, + WarpShape_, + ThreadblockShape_> { + using IteratorC = + typename cutlass::gemm::warp::MmaTensorOpAccumulatorTileIterator< + Shape_, + Element_, + Layout_, + InstructionShape_, + OpDelta_>; + using FragmentC = typename IteratorC::Fragment; + using InstructionShape = InstructionShape_; + using WarpShape = WarpShape_; + using ThreadblockShape = ThreadblockShape_; + using accum_t = Element_; + using lse_scalar_t = float; + + using SmemAccumulatorLayout = cutlass::layout::RowMajor; + + // Iterator to load accumulators (results of matmul in registers) + using FragmentIteratorAccumulator = + cutlass::epilogue::warp::FragmentIteratorTensorOp< + WarpShape, + InstructionShape, + accum_t, + typename Operator::Policy::Operator::FragmentC, + cutlass::layout::RowMajor>; + + // Iterator to store to shared-memory + using SmemIteratorD0 = typename cutlass::epilogue::warp::TileIteratorTensorOp< + WarpShape, + InstructionShape, + scalar_t, // accum_t, + SmemAccumulatorLayout>; + using AccumulatorSharedStorage = + cutlass::gemm::threadblock::AccumulatorSharedStorage< + ThreadblockShape, + typename SmemIteratorD0::Element, + typename SmemIteratorD0::TensorLayout, + typename SmemIteratorD0::Padding>; + // We need to provide an operation for the epilogue. Let's create an + // operation that does nothing (ScaleType::Nothing), just converts + // from accum_t (float) -> scalar_t (can be half) + using OutputOpNoOp = cutlass::epilogue::thread::LinearCombination< + typename SmemIteratorD0::Element, // ElementOutput + FragmentIteratorAccumulator::Fragment::kElements, + accum_t, // ElementAccumulator + typename SmemIteratorD0::Element, // ElementCompute + cutlass::epilogue::thread::ScaleType::Nothing>; + using Epilogue = cutlass::epilogue::threadblock::EpilogueSmemAccumulator< + SmemIteratorD0, + FragmentIteratorAccumulator, + SmemIteratorD0, // ScaleBiasIterator - not used + OutputOpNoOp>; + + // Epilogue 2: with LSE (for backwards pass) + static int const kElementsPerAccess = 2; // TODO: Why 2? 
+ using IteratorAccumulatorLSE = + cutlass::transform::threadblock::VectorIterator< + cutlass::transform::threadblock::PredicatedVectorAccessIterator< + // Shape + cutlass::MatrixShape, + // WarpShape + cutlass::MatrixShape, + lse_scalar_t, + cutlass::layout::RowMajor, + kElementsPerAccess>>; + using EpilogueOpApplyLSE = cutlass::epilogue::thread::ApplyLogSumExp< + scalar_t, // ElementOutput_ + lse_scalar_t, // ElementLSE_ + accum_t, // ElementAccumulator_ + accum_t, // ElementCompute_ + 128 / cutlass::sizeof_bits::value + // FragmentIteratorAccumulator::Fragment::kElements + // InstructionShape::kM * InstructionShape::kN / 32 + >; + using EpilogueWithLSE = + cutlass::epilogue::threadblock::EpilogueSmemAccumulator< + SmemIteratorD0, + FragmentIteratorAccumulator, + IteratorAccumulatorLSE, + EpilogueOpApplyLSE>; + + static void CUTLASS_DEVICE accumToSmem( + AccumulatorSharedStorage& shared_storage, + FragmentC const& accum, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id); + smem_iterator_attn.add_tile_offset( + tile_coords * + cutlass::MatrixCoord{ + SmemIteratorD0::TileIterations::kRow, + SmemIteratorD0::TileIterations::kColumn}); + Epilogue epilogue; + epilogue(OutputOpNoOp({}), smem_iterator_attn, accum); + } + + static void CUTLASS_DEVICE accumApplyLSEToSmem( + AccumulatorSharedStorage& shared_storage, + FragmentC& accum, + lse_scalar_t const* lse, + int32_t lse_extents, + int thread_id, + int warp_id, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + constexpr int32_t kAlignLSE = 32; + IteratorAccumulatorLSE iterator_lse( + lse, + {(int32_t)0, (int32_t)ceil_div(lse_extents, kAlignLSE) * kAlignLSE}, + thread_id, + warp_id, + cutlass::MatrixCoord{0, 0} // offset + ); + + SmemIteratorD0 smem_iterator_attn(shared_storage.accum_ref(), lane_id); + smem_iterator_attn.add_tile_offset( + tile_coords * + cutlass::MatrixCoord{ + SmemIteratorD0::TileIterations::kRow, + SmemIteratorD0::TileIterations::kColumn}); + EpilogueWithLSE epilogue; + EpilogueOpApplyLSE minus_lse_exp({}); + epilogue( + minus_lse_exp, + smem_iterator_attn, + accum, + // scale - unused + iterator_lse, + // bias + iterator_lse); + } +}; + +// Volta Specialization +// only supported for f16 +template +struct B2bGemm< + cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + cutlass::MatrixShape<32, 32>, + float, + cutlass::layout::RowMajor, + cutlass::gemm::GemmShape<16, 16, 4>, + cutlass::MatrixShape<1, 1>>, + Operator, + cutlass::half_t, + WarpShape_, + ThreadblockShape_> { + using IteratorC = + cutlass::gemm::warp::MmaVoltaTensorOpAccumulatorTileIterator< + cutlass::MatrixShape<32, 32>, + float, + cutlass::layout::RowMajor, + cutlass::gemm::GemmShape<16, 16, 4>, + cutlass::MatrixShape<1, 1>>; + using scalar_t = cutlass::half_t; + using accum_t = IteratorC::Element; + using WarpShape = WarpShape_; + using ThreadblockShape = ThreadblockShape_; + using FragmentC = IteratorC::Fragment; + using lse_scalar_t = float; + + // Storage in shared-memory for Q.Kt + using SmemAccumulatorLayout = + cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>; + using AccumulatorSharedStorage = + cutlass::gemm::threadblock::AccumulatorSharedStorage< + ThreadblockShape, + scalar_t, + SmemAccumulatorLayout, + cutlass::MatrixShape<0, 0> // Padding + >; + using TensorRef = cutlass::TensorRef; + using Policy = typename IteratorC::Policy; + using Element = accum_t; + // Those are MmaVoltaTensorOpAccumulatorTileIterator private fields + 
// Let's copy their values + static int const kElementsPerPartial = 4; + using EleShapePerPatial = typename cutlass::platform::conditional< + cutlass::platform::is_same::value, + cutlass::MatrixShape<2, 2>, + cutlass::MatrixShape<1, 4>>::type; + static int const kElementsPerMma = 8; + static int const kAccumulatorPatials = 2; + using QuadShapePerPatialMma = cutlass::MatrixShape<4, 4>; + + static void CUTLASS_DEVICE accumToSmem( + AccumulatorSharedStorage& shared_storage, + FragmentC const& accum, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + // ctor - from MmaVoltaTensorOpAccumulatorTileIterator + TensorRef ref_(shared_storage.accum_ref()); + int quad = (lane_id >> 2); + int lane_in_quad = (lane_id & 3); + int accum_m, accum_n; + + if (cutlass::platform::is_same::value) { + // (quad[2],quad[0])+lane_in_quad[0] + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + (lane_in_quad & 1); + // (quad[1])+lane_in_quad[1] + accum_n = + ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials + + (lane_in_quad & 2); + } else { + accum_m = (((quad & 0x4) >> 1) + (quad & 0x1)) * 8 + + lane_in_quad; // (quad[2],quad[0]) + accum_n = ((quad >> 1) & 0x1) * kElementsPerPartial * kAccumulatorPatials; + } + cutlass::MatrixCoord lane_offset(accum_m, accum_n); + + // Tile offset + ref_.add_coord_offset( + tile_coords * + cutlass::MatrixCoord( + {IteratorC::Shape::kRow, IteratorC::Shape::kColumn})); + + using AccessType = cutlass::Array; + + // store - from MmaVoltaTensorOpAccumulatorTileIterator + CUTLASS_PRAGMA_UNROLL + for (int tile_n = 0; tile_n < Policy::TileIterations::kColumn; ++tile_n) { + CUTLASS_PRAGMA_UNROLL + for (int tile_m = 0; tile_m < Policy::TileIterations::kRow; ++tile_m) { + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Policy::MmaIterations::kColumn; ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Policy::MmaIterations::kRow; ++mma_m) { + int mma_accum_start = + (((tile_n * Policy::TileIterations::kRow + tile_m) * + Policy::MmaIterations::kColumn + + mma_n) * + Policy::MmaIterations::kRow + + mma_m) * + kElementsPerMma; + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < kAccumulatorPatials; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < EleShapePerPatial::kRow; ++m) { + int accum_m = tile_m * Policy::InterleavedTile::kRow + + mma_m * QuadShapePerPatialMma::kRow + m * 2; + int accum_n = tile_n * Policy::InterleavedTile::kColumn + + mma_n * QuadShapePerPatialMma::kColumn + + p * Policy::InterleavedTile::kColumn / 2; + int r = (accum_m + lane_offset.row()); + AccessType to_store; + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < EleShapePerPatial::kColumn; ++n) { + int idx = mma_accum_start + p * kElementsPerPartial + + m * EleShapePerPatial::kColumn + n; + int c = (accum_n + n + lane_offset.column()); + to_store[n] = scalar_t(accum[idx]); + } + int c = (accum_n + lane_offset.column()); + assert(r < 32); + assert(c < 32); + *reinterpret_cast( + ref_.data() + ref_.offset({r, c})) = to_store; + } + } + } + } + } + } + } + + static void CUTLASS_DEVICE accumApplyLSEToSmem( + AccumulatorSharedStorage& shared_storage, + typename IteratorC::Fragment& accum, + lse_scalar_t const* lse, + int lse_extent, + int thread_id, + int warp_id, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + // Non-optimized way to apply LSE to registers + // NOTE: accum is attn.T + // TODO: Optimize for each architecture + static constexpr int WarpSize = 32; + using AccumLambdaIterator = + typename DefaultMmaAccumLambdaIterator:: + Iterator; + auto lane_offset = + 
AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords); + + cutlass::Array lse_prefetched; + lse_prefetched.clear(); + int rowIdx = 0; + int colIdx = 0; + AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + ++rowIdx; + colIdx = 0; + }, + [&](int accum_m, int accum_n, int idx) { + if (rowIdx == 1) { + lse_prefetched[colIdx] = accum_n < lse_extent + ? lse[accum_n] + : cutlass::platform::numeric_limits::infinity(); + } + accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]); + ++colIdx; + }, + [&](int accum_m) {}); + accumToSmem(shared_storage, accum, lane_id, tile_coords); + } +}; + +// Simt Specialization +// for f32 on Sm70-Sm75 and f16/f32 below + +template < + typename Operator, + typename OperatorPolicy, + typename scalar_t, + typename WarpShape_, + typename ThreadblockShape_> +struct B2bGemm< + cutlass::gemm::warp::MmaSimtTileIterator< + cutlass::MatrixShape<32, 32>, + cutlass::gemm::Operand::kC, + float, + cutlass::layout::RowMajor, + OperatorPolicy, + 1, + 1>, + Operator, + scalar_t, + WarpShape_, + ThreadblockShape_> { + using IteratorC = cutlass::gemm::warp::MmaSimtTileIterator< + cutlass::MatrixShape<32, 32>, + cutlass::gemm::Operand::kC, + float, + cutlass::layout::RowMajor, + OperatorPolicy, + 1, + 1>; + using accum_t = typename IteratorC::Element; + using WarpShape = WarpShape_; + using ThreadblockShape = ThreadblockShape_; + using FragmentC = typename IteratorC::Fragment; + using lse_scalar_t = float; + + // Storage in shared-memory for Q.Kt + using AccumulatorSharedStorage = + cutlass::gemm::threadblock::AccumulatorSharedStorage< + ThreadblockShape, + scalar_t, + cutlass::layout::ColumnMajor, + cutlass::MatrixShape<0, 0> // Padding + >; + + static void CUTLASS_DEVICE accumToSmem( + AccumulatorSharedStorage& shared_storage, + FragmentC const& accum, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + using Policy = typename IteratorC::Policy; + using Element = typename IteratorC::Element; + using Iterations = typename IteratorC::Iterations; + using Delta = typename IteratorC::Delta; + + auto ref_ = shared_storage.accum_ref(); + // ctor - MmaSimtTileIterator + // compute offset based on thread ID and lane layout + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + MatrixCoord lane_offset = lane_layout.inverse(lane_id) * + MatrixCoord(Policy::LaneMmaShape::kM, Policy::LaneMmaShape::kN); + + ref_.add_coord_offset(lane_offset); + + // Tile offset + ref_.add_coord_offset( + tile_coords * + cutlass::MatrixCoord( + {IteratorC::Shape::kRow, IteratorC::Shape::kColumn})); + + // store - MmaSimtTileIterator + CUTLASS_PRAGMA_UNROLL + for (int mma_n = 0; mma_n < Iterations::kColumn; ++mma_n) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::LaneMmaShape::kN; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int mma_m = 0; mma_m < Iterations::kRow; ++mma_m) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < Policy::LaneMmaShape::kM; ++m) { + int r = + Policy::LaneMmaShape::kM * (mma_m * Policy::WarpShape::kRow) + + m; + int c = mma_n * Delta::kColumn + n; + int idx = n + + Policy::LaneMmaShape::kN * + (mma_n + + Iterations::kColumn * + (m + mma_m * Policy::LaneMmaShape::kM)); + ref_.at({r, c}) = scalar_t(accum[idx]); + } + } + } + } + } + + static void CUTLASS_DEVICE accumApplyLSEToSmem( + AccumulatorSharedStorage& shared_storage, + typename IteratorC::Fragment& accum, + lse_scalar_t const* lse, + int lse_extent, + int thread_id, + int warp_id, + int lane_id, + cutlass::MatrixCoord const& tile_coords) { + // Non-optimized way to apply 
LSE to registers + // NOTE: accum is attn.T + // TODO: Optimize for each architecture + static constexpr int WarpSize = 32; + using AccumLambdaIterator = + typename DefaultMmaAccumLambdaIterator:: + Iterator; + auto lane_offset = + AccumLambdaIterator::get_lane_offset(lane_id, warp_id, tile_coords); + + cutlass::Array lse_prefetched; + lse_prefetched.clear(); + int rowIdx = 0; + int colIdx = 0; + AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + ++rowIdx; + colIdx = 0; + }, + [&](int accum_m, int accum_n, int idx) { + if (rowIdx == 1) { + lse_prefetched[colIdx] = accum_n < lse_extent + ? lse[accum_n] + : cutlass::platform::numeric_limits::infinity(); + } + accum[idx] = expf(accum[idx] - lse_prefetched[colIdx]); + ++colIdx; + }, + [&](int accum_m) {}); + accumToSmem(shared_storage, accum, lane_id, tile_coords); + } +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/41_fused_multi_head_attention/gemm_kernel_utils.h b/examples/41_fused_multi_head_attention/gemm_kernel_utils.h new file mode 100644 index 0000000000..a770e0b671 --- /dev/null +++ b/examples/41_fused_multi_head_attention/gemm_kernel_utils.h @@ -0,0 +1,258 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/arch/mma.h" + +//////////////////////////////////////////////////////////////////////////////// +// Some helper functions +//////////////////////////////////////////////////////////////////////////////// +#define DISPATCH_TYPES(tensor, func) \ + { \ + if (query.scalar_type() == at::ScalarType::Float) { \ + using scalar_t = float; \ + func(); \ + } else if (query.scalar_type() == at::ScalarType::Half) { \ + using scalar_t = cutlass::half_t; \ + func(); \ + } else if (query.scalar_type() == at::ScalarType::BFloat16) { \ + using scalar_t = cutlass::bfloat16_t; \ + func(); \ + } else { \ + XFORMERS_CHECK(false, "Only fp32, half & bf16 supported at the moment"); \ + } \ + } + +#define DISPATCH_BOOL(BOOL_V, BOOL_NAME, F) \ + { \ + if (BOOL_V) { \ + using BOOL_NAME = std::true_type; \ + F(); \ + } else { \ + using BOOL_NAME = std::false_type; \ + F(); \ + } \ + } + +#define DISPATCH_ARCHTAG(CC, func) \ + { \ + if (CC >= 80) { \ + using ArchTag = cutlass::arch::Sm80; \ + func(); \ + } else if (CC >= 75) { \ + using ArchTag = cutlass::arch::Sm75; \ + func(); \ + } else if (CC >= 70) { \ + using ArchTag = cutlass::arch::Sm70; \ + func(); \ + } else if (CC >= 50) { \ + using ArchTag = cutlass::arch::Sm50; \ + func(); \ + } else { \ + XFORMERS_CHECK( \ + false, \ + "Your device is too old. We require compute capability >= 50"); \ + } \ + } + +#define CHECK_NOSPARSE_CONTIGUOUS_CUDA(TENSOR) \ + XFORMERS_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + XFORMERS_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + XFORMERS_CHECK(TENSOR.is_contiguous()); + +#define CHECK_NOSPARSE_LASTCONTIGUOUS_CUDA(TENSOR) \ + XFORMERS_CHECK(TENSOR.is_cuda(), #TENSOR " must be a CUDA tensor"); \ + XFORMERS_CHECK(!TENSOR.is_sparse(), #TENSOR " must be a dense tensor"); \ + XFORMERS_CHECK( \ + TENSOR.stride(-1) == 1, #TENSOR ": last dimension must be contiguous"); + +#ifdef TORCH_CHECK +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + XFORMERS_CHECK( \ + uint64_t(PTR) % ALIGNMENT == 0, #PTR " is not correctly aligned") +#define XFORMERS_CHECK TORCH_CHECK +#elif defined(__CUDACC_RTC__) +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \ + return false; \ + } +#define XFORMERS_CHECK(COND, ERR) \ + if (!(COND)) { \ + return false; \ + } +#else +#include +#define CHECK_ALIGNED_PTR(PTR, ALIGNMENT) \ + if (!(uint64_t(PTR) % ALIGNMENT == 0)) { \ + std::cerr << #PTR " is not correctly aligned\n"; \ + return false; \ + } +#define XFORMERS_CHECK(COND, ERR) \ + if (!(COND)) { \ + std::cerr << "'" #COND "' failed: " << ERR << "\n"; \ + return false; \ + } +#endif + +#define ASSIGN_CHECK_OVERFLOW(A, B) \ + { \ + A = B; \ + XFORMERS_CHECK( \ + B < std::numeric_limits::max(), #B " overflows"); \ + } + +namespace gemm_kernel_utils { + +template +constexpr CUTLASS_HOST_DEVICE integer ceil_div(integer n, integer m) { + return (n + m - 1) / m; +} + +template +constexpr CUTLASS_HOST_DEVICE integer align_up(integer n, integer m) { + return ((n + m - 1) / m) * m; +} + +//////////////////////////////////////////////////////////////////////////////// +// Determine the type of GEMM we do (TensorCores or not, Shapes ...) 
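+// DefaultGemmType<ArchTag, scalar_t> is a small traits bundle: it selects the
+// operator class (SIMT vs. TensorOp), MMA instruction shape, minimum
+// alignment and math operator for a given architecture / element type, with
+// the unspecialized template falling back to SIMT FMA. A hedged usage sketch
+// (relying only on the Sm75+ 16-bit specialization defined below):
+//
+//   using Traits = DefaultGemmType<cutlass::arch::Sm80, cutlass::half_t>;
+//   static_assert(cutlass::platform::is_same<typename Traits::OpClass,
+//                     cutlass::arch::OpClassTensorOp>::value, "");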
+// TODO: Maybe we could rely on Cutlass's DefaultGemm templates +//////////////////////////////////////////////////////////////////////////////// + +// Fallback to Simt (FMA on cuda cores) if not in a special case below +template +struct DefaultGemmType { + static constexpr int ThreadK = 8; + static constexpr int WarpK = 8; + static constexpr int kMinimumAlignment = 1; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using OpClass = cutlass::arch::OpClassSimt; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Specialization for tensorcores with f32 +template +struct DefaultGemmType< + ArchTag, + float, + typename cutlass::platform::enable_if< + ArchTag::kMinComputeCapability >= 80>::type> { + static constexpr int ThreadK = 32; + static constexpr int WarpK = 32; + static constexpr int kMinimumAlignment = 4; + using OpClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Operator = cutlass::arch::OpMultiplyAddFastF32; +}; + +// Specialization for tensorcores with f16/bf16 - Sm75+ +template +struct DefaultGemmType< + ArchTag, + scalar_t, + typename cutlass::platform::enable_if< + ArchTag::kMinComputeCapability >= 75 && + cutlass::sizeof_bits::value == 16>::type> { + static constexpr int ThreadK = 32; + static constexpr int WarpK = 32; + static constexpr int kMinimumAlignment = 4; + using OpClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Specialization for tensorcores with f16 - Volta +template <> +struct DefaultGemmType { + static constexpr int ThreadK = 32; + static constexpr int WarpK = 32; + static constexpr int kMinimumAlignment = 2; + using OpClass = cutlass::arch::OpClassTensorOp; + using InstructionShape = cutlass::gemm::GemmShape<8, 8, 4>; + using Operator = cutlass::arch::OpMultiplyAdd; +}; + +// Enables to do +// `auto x = kCondition ? 
fa(arg) : fb(arg)` +// when `fa` and `fb` have different types +template +struct call_conditional; + +template +struct call_conditional { + template + static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) + -> decltype(ta(arg)) { + return ta(arg); + } +}; + +template +struct call_conditional { + template + static CUTLASS_HOST_DEVICE auto apply(TA ta, TB tb, Arg arg) + -> decltype(tb(arg)) { + return tb(arg); + } +}; + +//////////////////////////////////////////////////////////////////////////////// +// Mark a variable as warp-uniform - enables some compiler optimizations +// The cheapest way to do it is just to broadcast it from lane 0 +//////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_DEVICE T warp_uniform(T value) { + struct { + union { + T value; + uint32_t asInt; + }; + } p; + p.value = value; + p.asInt = __shfl_sync(0xffffffff, (unsigned)p.asInt, 0); + return p.value; +} + +template +CUTLASS_DEVICE T* warp_uniform(T* ptr) { + struct { + union { + T* ptr; + uint32_t asInt[2]; + }; + } p; + p.ptr = ptr; + p.asInt[0] = warp_uniform(p.asInt[0]); + p.asInt[1] = warp_uniform(p.asInt[1]); + return p.ptr; +} +} // namespace gemm_kernel_utils diff --git a/examples/41_fused_multi_head_attention/iterators/default_warp_iterator_from_smem.h b/examples/41_fused_multi_head_attention/iterators/default_warp_iterator_from_smem.h new file mode 100644 index 0000000000..3dbb0cf285 --- /dev/null +++ b/examples/41_fused_multi_head_attention/iterators/default_warp_iterator_from_smem.h @@ -0,0 +1,142 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Instanciates the right WarpIterator to read from shared memory + The class `DefaultWarpIteratorAFromSharedMemory` is useful when reading + data dumped with `B2bGemm::accumToSmem`. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" +#include "cutlass/platform/platform.h" + +#include "warp_iterator_from_smem.h" + +namespace cutlass { +namespace gemm { +namespace threadblock { + +template < + typename WarpShape, + typename InstructionShape, + typename RegularWarpIterator, + typename Policy, + typename Enable = void> +struct DefaultWarpIteratorAFromSharedMemory {}; + +// TensorOp - Ampere half +template +struct DefaultWarpIteratorAFromSharedMemory< + cutlass::gemm::GemmShape<32, 32, 32>, + cutlass::gemm::GemmShape<16, 8, kInstrK>, + RegularWarpIterator, + Policy, + typename platform::enable_if<( + sizeof_bits::value == 16 && + Policy::Operator::Policy::OpDelta::kRow == 1)>::type> { + using OpDelta = typename Policy::Operator::Policy::OpDelta; + using WarpShape = cutlass::MatrixShape<32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, kInstrK>; + + using WarpIterator = cutlass::gemm::warp::WarpIteratorFromSmem< + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::MatrixShape>; +}; + +// TensorOp - Ampere f32 +template +struct DefaultWarpIteratorAFromSharedMemory< + WarpShape, + cutlass::gemm::GemmShape<16, 8, 8>, + RegularWarpIterator, + Policy, + typename platform::enable_if<( + sizeof_bits::value != 16 || + Policy::Operator::Policy::OpDelta::kRow != 1)>::type> { + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 8>; + static constexpr auto kWarpSize = 32; + using OpDelta = typename Policy::Operator::Policy::OpDelta; + + using WarpIterator = + cutlass::gemm::warp::MmaTensorOpMultiplicandTileAccessIterator< + cutlass::MatrixShape, + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::layout::RowMajor, + cutlass::MatrixShape, + OpDelta::kRow, + kWarpSize>; +}; + +// TensorOp - Volta +template +struct DefaultWarpIteratorAFromSharedMemory< + WarpShape, + cutlass::gemm::GemmShape<16, 16, 4>, + RegularWarpIterator, + Policy> { + using InstructionShape = cutlass::gemm::GemmShape<16, 16, 4>; + static constexpr auto kWarpSize = 32; + using OpDelta = typename Policy::Operator::Policy::OpDelta; + + using WarpIterator = + cutlass::gemm::warp::MmaVoltaTensorOpMultiplicandTileIterator< + cutlass::MatrixShape<32, 32>, // MatrixShape, + cutlass::gemm::Operand::kA, + typename RegularWarpIterator::Element, + cutlass::layout::RowMajorVoltaTensorOpMultiplicandCrosswise<16, 32>, + cutlass::MatrixShape<16, 4>, + OpDelta::kRow, + kWarpSize>; +}; + +// Simt +template +struct DefaultWarpIteratorAFromSharedMemory< + WarpShape, + cutlass::gemm::GemmShape<1, 1, 1>, + RegularWarpIterator, + Policy> { + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + static constexpr auto kWarpSize = 32; + + // We just use the same iterator, as we reproduced the same shared-memory + // schema. Just modify it to handle non-complete tiles. 
+ using WarpIterator = RegularWarpIterator; +}; + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass diff --git a/examples/41_fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h b/examples/41_fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h new file mode 100644 index 0000000000..64a58278fe --- /dev/null +++ b/examples/41_fused_multi_head_attention/iterators/epilogue_predicated_tile_iterator.h @@ -0,0 +1,751 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue iterator that supports prefetching + + Mostly copied from "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +*/ + +#pragma once + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator_params.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/transform/pitch_linear_thread_map.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +//////////////////////////////////////////////////////////////////////////////// + +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Tile iterator used to load and store output tile from global memory in +/// epilogue. 
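+/// In addition to load()/store(), it exposes prefetch() and prefetch_all(),
+/// which walk the same row/group/cluster schedule but only issue
+/// prefetch.global.L1 hints so the tile's cache lines are resident before the
+/// actual accesses.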
+/// +/// Satisfies: ReadableTileIterator | PredicatedTileIterator | +/// ForwardTileIterator +/// +template < + typename ThreadMap_, ///< Thread map (conept: OutputTileThreadMap) + typename Element_, ///< Element data type + bool ScatterD = false, ///< Scatter D operand or not + bool UseCUDAStore = false> +class PredicatedTileIteratorPrefetch { + public: + using ThreadMap = ThreadMap_; + using Shape = typename ThreadMap::Shape; + + using Element = Element_; + + using Layout = layout::RowMajor; + using TensorRef = TensorRef; + using ConstTensorRef = typename TensorRef::ConstTensorRef; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + using TensorCoord = MatrixCoord; + + static int const kElementsPerAccess = ThreadMap::kElementsPerAccess; + static int const kThreads = ThreadMap::kThreads; + static int const kIterations = ThreadMap::Count::kTile; + + static_assert( + ThreadMap::Iterations::kRow > 0, + "ThreadMap::Iterations::kRow must be > 0"); + static_assert( + ThreadMap::Iterations::kGroup > 0, + "ThreadMap::Iterations::kGroup must be > 0"); + static_assert( + ThreadMap::Iterations::kCluster > 0, + "ThreadMap::Iterations::kCluster must be > 0"); + static_assert( + ThreadMap::Iterations::kColumn > 0, + "ThreadMap::Iterations::kColumn must be > 0"); + + /// Fragment object + using Fragment = Array< + Element, + ThreadMap::Iterations::kColumn * ThreadMap::Iterations::kRow * + ThreadMap::Iterations::kGroup * ThreadMap::Iterations::kCluster * + ThreadMap::kElementsPerAccess>; + + /// Memory access size + using AccessType = AlignedArray; + + // + // Parameters struct + // + + /// Uses a non-template class + struct Params : PredicatedTileIteratorParams { + using Base = PredicatedTileIteratorParams; + + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : PredicatedTileIteratorParams( + layout.stride(0) * int(sizeof(AccessType)) / kElementsPerAccess, + make_OutputTileThreadMapDesc()) {} + + CUTLASS_HOST_DEVICE + Params(Base const& base) : Base(base) {} + }; + + /// Mask object + struct Mask { + static int const kCount = ThreadMap::Iterations::kColumn; + + /// Predicate state + bool predicates[kCount]; + + // + // Mask + // + CUTLASS_HOST_DEVICE + Mask() { + enable(); + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_HOST_DEVICE void clear() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = false; + } + } + + ///< CUTLASS_HOST_DEVICE enables all accesses guarded by mask + CUTLASS_DEVICE void enable() { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kCount; ++i) { + predicates[i] = true; + } + } + }; + + private: + // + // Data members + // + + /// Parameters structure containing reference and precomputed state. 
+ PredicatedTileIteratorParams params_; + + /// Byte-level pointer + uint8_t* byte_pointer_; + + /// Array of boolean values to contain steady-state predicates + Mask mask_; + + /// Extent of the matrix tile in rows + Index extent_row_; + + /// Extent of the matrix tile in rows + Index extent_column_; + + /// A thread's starting row position (assuming steady-state predicates have + /// been computed) + Index thread_start_row_; + + /// A thread's starting column + Index thread_start_column_; + + /// Internal state counter + int state_[3]; + + /// Scatter indices + int const* indices_; + + // + // Static asserts about internal strides + // + + static_assert(sizeof(extent_row_) == 4, "Expected 32b extents"); + static_assert(sizeof(thread_start_row_) == 4, "Expected 32b extents"); + static_assert( + sizeof(PredicatedTileIteratorParams::stride) == 8, + "Expected 64b strides"); + + private: + // + // Methods + // + + public: + // + // Methods + // + + /// Constructor + CUTLASS_DEVICE + PredicatedTileIteratorPrefetch( + PredicatedTileIteratorParams const& params, + Element* pointer, + TensorCoord extent, + int thread_idx, + TensorCoord threadblock_offset = TensorCoord(), + int const* indices = nullptr) + : params_(params), indices_(indices) { + TensorCoord thread_offset = + ThreadMap::initial_offset(thread_idx) + threadblock_offset; + + extent_row_ = extent.row(); + extent_column_ = extent.column(); + + thread_start_row_ = thread_offset.row(); + thread_start_column_ = thread_offset.column(); + + // Initialize predicates + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kColumn; ++c) { + mask_.predicates[c] = + ((thread_offset.column() + ThreadMap::Delta::kColumn * c) < + extent.column()); + } + + // Null pointer performs no accesses + if (!pointer) { + mask_.clear(); + } + + if (ScatterD && !indices) { + mask_.clear(); + } + + // Initialize pointer + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.row()) * LongIndex(params_.stride) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / + kElementsPerAccess; + + if (ScatterD) { + byte_pointer_ = reinterpret_cast(pointer) + + LongIndex(thread_offset.column()) * sizeof(AccessType) / + kElementsPerAccess; + } + + // Initialize internal state counter + state_[0] = state_[1] = state_[2] = 0; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + byte_pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_DEVICE + void prefetch_all() { + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < kIterations; ++iter) { + prefetch(); + ++(*this); + } + } + + CUTLASS_DEVICE + void prefetch() { + uint8_t* byte_pointer = byte_pointer_; + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + // on windows using unsigned long here gives the error + // error: asm operand type size(4) does not match + // type/size implied by constraint 'l' + uint64_t addr = (uint64_t)((void*)&memory_pointer + [column * 
ThreadMap::Delta::kColumn / + kElementsPerAccess]); + asm volatile("prefetch.global.L1 [ %1 ];" : "=l"(addr) : "l"(addr)); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, int64_t byte_offset) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * + LongIndex(params_.stride)); + } + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) const { + load_with_byte_offset(frag, 0); + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, int64_t byte_offset) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + if (ScatterD && row_guard) { + assert(indices_); + + memory_pointer = reinterpret_cast( + byte_pointer + byte_offset + + LongIndex(indices_[row_offset + thread_start_row_]) * + LongIndex(params_.stride)); + } + + 
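+          // Each column access below is predicated on both the row guard and
+          // the per-column mask, so partial tiles at the matrix edges store
+          // only their in-bounds elements.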
CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + if (UseCUDAStore) { + if (guard) { + memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess] = + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + + column]; + } + } else { + cutlass::arch::global_store( + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + if (!ScatterD) { + byte_pointer += params_.increment_row; + } + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Stores a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) const { + store_with_byte_offset(frag, 0); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void downsample_load_with_byte_offset( + Fragment& frag, + int64_t byte_offset, + int convolution_P, + int convolution_Q, + int add_P, + int add_Q, + int problem_N) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + + int input_row = output_N * 2 * convolution_P * 2 * convolution_Q + + (2 * output_P + add_P) * 2 * convolution_Q + 2 * output_Q + add_Q; + + int64_t byte_offset = + (input_row - output_row) * problem_N * sizeof(float); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + byte_pointer += params_.increment_row; + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void upsample_load_with_byte_offset( + Fragment& frag, + int64_t byte_offset, + int convolution_P, + int convolution_Q, + int add_P, + int add_Q, + int problem_N) const { + uint8_t* byte_pointer = byte_pointer_; + AccessType* frag_ptr = reinterpret_cast(&frag); + + 
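+    // This variant remaps every output row onto a half-resolution
+    // (convolution_P/2 x convolution_Q/2) input row before loading, i.e. it
+    // effectively performs a 2x nearest-neighbour upsample via a per-row byte
+    // offset; downsample_load_with_byte_offset above applies the inverse
+    // stride-2 mapping.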
CUTLASS_PRAGMA_UNROLL + for (int cluster = 0; cluster < ThreadMap::Iterations::kCluster; + ++cluster) { + CUTLASS_PRAGMA_UNROLL + for (int group = 0; group < ThreadMap::Iterations::kGroup; ++group) { + CUTLASS_PRAGMA_UNROLL + for (int row = 0; row < ThreadMap::Iterations::kRow; ++row) { + int frag_row_idx = + (row + + ThreadMap::Iterations::kRow * + (group + ThreadMap::Iterations::kGroup * cluster)); + + int row_offset = row * ThreadMap::Delta::kRow + + group * ThreadMap::Delta::kGroup + + cluster * ThreadMap::Delta::kCluster; + + bool row_guard = ((row_offset + thread_start_row_) < extent_row_); + + int output_row = row_offset + thread_start_row_; + int output_N = output_row / (convolution_P * convolution_Q); + int output_PQ = output_row % (convolution_P * convolution_Q); + int output_P = output_PQ / convolution_Q; + int output_Q = output_PQ % convolution_Q; + int row_add_P = add_P; + int row_add_Q = add_Q; + if (output_P > convolution_P - 2) + row_add_P = 0; + if (output_Q > convolution_Q - 2) + row_add_Q = 0; + + int input_row = output_N * (convolution_P / 2) * (convolution_Q / 2) + + ((output_P + row_add_P) / 2) * (convolution_Q / 2) + + (output_Q + row_add_Q) / 2; + + int64_t byte_offset = + (input_row - output_row) * problem_N * sizeof(float); + + AccessType* memory_pointer = + reinterpret_cast(byte_pointer + byte_offset); + + CUTLASS_PRAGMA_UNROLL + for (int column = 0; column < ThreadMap::Iterations::kColumn; + ++column) { + bool guard = row_guard && mask_.predicates[column]; + + cutlass::arch::global_load( + frag_ptr + [frag_row_idx * ThreadMap::Iterations::kColumn + column], + (void*)&memory_pointer + [column * ThreadMap::Delta::kColumn / kElementsPerAccess], + guard); + } + + if (row + 1 < ThreadMap::Iterations::kRow) { + byte_pointer += params_.increment_row; + } + } + + if (group + 1 < ThreadMap::Iterations::kGroup) { + byte_pointer += params_.increment_group; + } + } + + if (cluster + 1 < ThreadMap::Iterations::kCluster) { + byte_pointer += params_.increment_cluster; + } + } + } + + CUTLASS_DEVICE + MatrixCoord thread_start() const { + return MatrixCoord(thread_start_row_, thread_start_column_); + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_row() const { + return thread_start_row_; + } + + /// Need to get the thread start row from the tile iterator + CUTLASS_DEVICE + int32_t thread_start_column() const { + return thread_start_column_; + } + + /// Extent of the matrix in rows + CUTLASS_DEVICE + Index extent_row() const { + return extent_row_; + } + + /// Extent of the matrix in columns + CUTLASS_DEVICE + Index extent_column() const { + return extent_column_; + } + + /// Advances to the next position to load or store + CUTLASS_HOST_DEVICE + PredicatedTileIteratorPrefetch& operator++() { + ++state_[0]; + + if (!ScatterD) { + byte_pointer_ += params_.advance_row; + } + + thread_start_row_ += ThreadMap::Shape::kRow; + + if (state_[0] == ThreadMap::Count::kRow) { + state_[0] = 0; + ++state_[1]; + byte_pointer_ += params_.advance_group; + + thread_start_row_ += (ThreadMap::Shape::kGroup - 1) * + ThreadMap::Shape::kRow * ThreadMap::Count::kRow; + + if (state_[1] == ThreadMap::Count::kGroup) { + state_[1] = 0; + ++state_[2]; + byte_pointer_ += params_.advance_cluster; + + thread_start_row_ += ThreadMap::Count::kGroup * + ThreadMap::Shape::kGroup * ThreadMap::Count::kRow * + ThreadMap::Shape::kRow; + + if (state_[2] == ThreadMap::Count::kCluster) { + state_[2] = 0; + byte_pointer_ += params_.advance_tile; + } + } + } + + 
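+    // State summary: state_[0], state_[1] and state_[2] count row, group and
+    // cluster iterations in the order the ThreadMap enumerates them.
+    // advance_row is applied on every increment (unless ScatterD), while
+    // advance_group, advance_cluster and advance_tile are applied only when
+    // the row, group and cluster counters wrap. thread_start_row_ is advanced
+    // in step so the row guards used by load()/store() remain valid.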
return *this; + } + + ///< Efficiently disables all accesses guarded by mask + CUTLASS_DEVICE void clear_mask() { + mask_.clear(); + } + + ///< Efficiently enables all accesses guarded by mask + CUTLASS_DEVICE void enable_mask() { + mask_.enable(); + } + + ///< Sets the mask + CUTLASS_DEVICE void get_mask(Mask& mask) const { + mask = mask_; + } + + ///< Sets the mask + CUTLASS_DEVICE void set_mask(Mask const& mask) { + mask_ = mask; + } +}; + +template +struct MakePrefetchableIterator { + using Iterator = PredicatedTileIteratorPrefetch< + typename IT::ThreadMap, + typename IT::Element>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/41_fused_multi_head_attention/iterators/make_residual_last.h b/examples/41_fused_multi_head_attention/iterators/make_residual_last.h new file mode 100644 index 0000000000..845a3c6b7a --- /dev/null +++ b/examples/41_fused_multi_head_attention/iterators/make_residual_last.h @@ -0,0 +1,97 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "predicated_tile_access_iterator_residual_last.h" +#include "predicated_tile_iterator_residual_last.h" + +namespace cutlass { +namespace transform { +namespace threadblock { + +template +struct MakeIteratorResidualLast; + +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + int AccessSize, + bool Gather> +struct MakeIteratorResidualLast> { + using Iterator = PredicatedTileIteratorResidualLast< + Shape, + Element, + Layout, + AdvanceRank, + ThreadMap, + AccessSize, + Gather>; +}; + +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + typename AccessType, + bool Gather> +struct MakeIteratorResidualLast> { + using Iterator = PredicatedTileAccessIteratorResidualLast< + Shape, + Element, + Layout, + AdvanceRank, + ThreadMap, + AccessType, + Gather>; +}; +} // namespace threadblock +} // namespace transform +} // namespace cutlass diff --git a/examples/41_fused_multi_head_attention/iterators/predicated_tile_access_iterator_residual_last.h b/examples/41_fused_multi_head_attention/iterators/predicated_tile_access_iterator_residual_last.h new file mode 100644 index 0000000000..6bc9e52c3c --- /dev/null +++ b/examples/41_fused_multi_head_attention/iterators/predicated_tile_access_iterator_residual_last.h @@ -0,0 +1,2114 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates calculating the address and predicates to the load of tiles + from pitch-linear rank=2 tensors. + + This iterator uses masks to guard out-of-bounds accesses. 
The first tile + this iterator visits maybe partial, then the remaining tiles are complete. + So, we only need to compute the predicates twice, once before the first tile + and once for the remaining full tiles which can share the same predicates. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator_params.h" + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileAccessIteratorResidualLast +/// +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + typename AccessType, + bool Gather = false> +class PredicatedTileAccessIteratorResidualLast; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for pitch-linear +/// data. +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + bool Gather> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::PitchLinear, + AdvanceRank, + ThreadMap_, + AccessType_, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates< + Shape, + Element, + Layout, + AdvanceRank, + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = + ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert( + !(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + using Mask = typename UnderlyingPredicates::Mask; + + /// Uses a non-template class + struct Params : PredicatedTileAccessIteratorParams { + using Base = PredicatedTileAccessIteratorParams; + + // Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : Base( + layout.stride(0), + MakePredicatedTileAccessIteratorDesc< + Shape, + Element, + Layout, + kAdvanceRank, + ThreadMap>()()) {} + + CUTLASS_HOST_DEVICE + Params(Base const& base) : 
Base(base) {} + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + UnderlyingPredicates the_predicates; + Mask residual_tile_mask; + + /// Parameters object with precomputed internal state + Params params_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + /// Below is used when Gather is turned on. We need to record strided_offset + /// and contiguous_offset separated to compute the offset by using + /// + /// offset = contiguous_offset + indices[strided_offset] + /// + + /// Gather indices + int const* indices_; + + Index gather_offset_strided; + + private: + /// Computes predicates based on internally tracked per-thread offset. + CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) { + the_predicates.compute_predicates_(extent, is_steady_state); + } + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + /// Gather indices + int const* indices = nullptr) + : params_(params), + pointer_(reinterpret_cast( + const_cast(pointer))), + the_predicates(extent), + indices_(indices) { + the_predicates.set_predicates(thread_id, threadblock_offset); + the_predicates.get_mask(residual_tile_mask); + + // Working around a weird compiler bug happening on P100 for the backward. + // I've seen together: the_predicates.predicates_[0] = 14 (instead of 15) + // residual_tile_mask[0] = 15 (correct) + // + // Adding prints when the value is calculated (in `compute_predicates_`) + // sometimes removes the bug. The consequence is that we skip some + // element of a tensor, leading to wrong results + // Setting `compute_predicates_`'s second argument (`is_steady_state`) to + // true also seems to get rid of the bug - at the cost of twice as many + // comparisons. 
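+    // How the two masks set up here are expected to be used (sketch only; the
+    // loop and iterator names below are illustrative, not part of this file):
+    // `residual_tile_mask` saved above guards the partial tile implied by
+    // `extent`, while the mask computed below assumes full, steady-state
+    // tiles. A mainloop would keep the steady-state mask and swap in the
+    // residual mask only for the final, possibly partial tile, e.g.
+    //
+    //   for (int k = k_iterations; k > 0; --k) {
+    //     iter.set_residual_tile(k == 1);  // last tile may be partial
+    //     // ... issue accesses via iter.get() / iter.valid() ...
+    //     ++iter;
+    //   }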
+#if !defined(__CUDA_ARCH__) || (__CUDA_ARCH__ >= 700) + constexpr bool kWorkAroundCompilerBug = false; +#else + constexpr bool kWorkAroundCompilerBug = true; +#endif + the_predicates.compute_predicates_(extent, true && !kWorkAroundCompilerBug); + + // update internal pointers + Layout layout(params_.stride_); + + if (!Gather) { + add_pointer_offset(layout(the_predicates.thread_offset_)); + } else { + gather_offset_strided = the_predicates.thread_offset_.strided(); + add_pointer_offset( + layout(make_Coord(the_predicates.thread_offset_.contiguous(), 0))); + } + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + the_predicates.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool is_residual_tile) { + if (is_residual_tile) { + the_predicates.set_mask(residual_tile_mask); + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += sizeof_bits::value * pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + if (!Gather) { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.strided()); + pointer_ += Shape::kContiguous * tile_offset.contiguous(); + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset.contiguous()); + pointer_ += Shape::kStrided * tile_offset.strided(); + } + } else { + add_pointer_offset(Shape::kContiguous * tile_offset.contiguous()); + gather_offset_strided += Shape::kStrided * tile_offset.strided(); + } + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + if (Gather) { + assert(indices_); + + if (!valid()) { + return nullptr; + } + + LongIndex contiguous_offset = the_predicates.iteration_contiguous_ * + (ThreadMap::Delta::kContiguous * sizeof_bits::value / + 8) + + the_predicates.iteration_vector_; + int strided_index = gather_offset_strided + + the_predicates.iteration_strided_ * ThreadMap::Delta::kStrided; + + LongIndex strided_offset = indices_[strided_index] * + LongIndex(params_.stride_) * sizeof_bits::value / 8; + + return reinterpret_cast( + pointer_ + contiguous_offset + strided_offset); + } + + return reinterpret_cast( + pointer_ + + the_predicates.iteration_contiguous_ * + (ThreadMap::Delta::kContiguous * + sizeof_bits::value) / + 8) + + the_predicates.iteration_vector_; + } + + /// Increment and return an instance to self. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + the_predicates.operator++(); + + ++the_predicates.iteration_vector_; + if (the_predicates.iteration_vector_ < kAccessesPerVector) { + return *this; + } + + the_predicates.iteration_vector_ = 0; + ++the_predicates.iteration_contiguous_; + + if (the_predicates.iteration_contiguous_ < + ThreadMap::Iterations::kContiguous) { + return *this; + } + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + the_predicates.iteration_contiguous_ = 0; + ++the_predicates.iteration_strided_; + + if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { + if (!Gather) { + pointer_ += params_.inc_strided_; + } + + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + the_predicates.iteration_strided_ = 0; + + if (!Gather) { + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, + // this subtraction as well as the subsequent integer addition are both + // elided by the compiler. + pointer_ -= params_.inc_advance_; + } + + return *this; + } + + /// Increment and return an instance to self. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + the_predicates.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + the_predicates.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + the_predicates.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + the_predicates.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() const { + return the_predicates.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for column-major +/// data. 
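+/// This is a thin adapter: the column-major problem is reinterpreted as a
+/// pitch-linear one with (contiguous, strided) == (rows, columns), and every
+/// method forwards to the pitch-linear PredicatedTileAccessIteratorResidualLast
+/// defined above.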
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + bool Gather> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::ColumnMajor, + AdvanceRank, + ThreadMap_, + AccessType_, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessType, + Gather>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))){}; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row(), + threadblock_offset.column()), + indices) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : 
PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for row-major +/// data. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + bool Gather> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::RowMajor, + AdvanceRank, + ThreadMap_, + AccessType_, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessType, + Gather>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))){}; + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + /// Gather indices + int const* indices = nullptr) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column(), + threadblock_offset.row()), + indices) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} 
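+  // Note: relative to the column-major specialization above, the coordinate
+  // roles are transposed here -- the underlying pitch-linear iterator sees
+  // (contiguous, strided) == (columns, rows). For a row-major M x N tensor the
+  // constructors above therefore pass the extent as PitchLinearCoord(N, M) and
+  // the threadblock offset as (offset.column(), offset.row()).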
+ + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank 2 +/// data. 
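+/// Unlike the pitch-linear case, an AffineRankN<2> layout carries two run-time
+/// strides, so the nested Params object precomputes explicit byte increments
+/// (inc_contiguous_, inc_strided_, inc_next_strided_, inc_next_, inc_advance_)
+/// instead of reusing PredicatedTileAccessIteratorParams; for example,
+/// inc_contiguous_ is stride(0) * ThreadMap::Delta::kContiguous elements,
+/// converted to bytes.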
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::AffineRankN<2>, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRankN<2>; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingPredicates = PredicatedTileAccessIteratorPredicates< + Shape, + Element, + layout::PitchLinear, + AdvanceRank, + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = + ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert( + !(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingPredicates::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend PredicatedTileAccessIteratorResidualLast; + + private: + /// stride of pitch-linear layout (units of Element) + Coord stride_; + /// amount (in byte) to increment pointer to move to next access along + /// contiguous dimension + LongIndex inc_contiguous_; + /// amount (in byte) to increment pointer from first access of current + /// contiguous dimension to first access of next one. + LongIndex inc_strided_; + /// amount (in byte) to increment pointer from last access of current + /// contiguous dimension to first access of next one. 
+ LongIndex inc_next_strided_; + /// amount (in byte) to increment pointer from last access to first access + /// of next tile + LongIndex inc_next_; + /// amount (in byte) to increment pointer from first access of current tile + /// to first access of next tile + LongIndex inc_advance_; + + public: + // Default ctor + CUTLASS_HOST_DEVICE + Params() + : stride_(0), + inc_contiguous_(0), + inc_strided_(0), + inc_next_(0), + inc_advance_(0) {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : stride_({layout.stride(0), layout.stride(1)}) { + inc_contiguous_ = + (LongIndex(stride_[0]) * ThreadMap::Delta::kContiguous) * + sizeof_bits::value / 8; + + inc_strided_ = (LongIndex(stride_[1]) * ThreadMap::Delta::kStrided) * + sizeof_bits::value / 8; + + inc_next_strided_ = inc_strided_ - + LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_; + + if (kAdvanceRank) { + // advance along strided dimension + inc_advance_ = Shape::kStrided * LongIndex(stride_[1]) * + sizeof_bits::value / 8; + } else { + // advance along contiguous dimension + inc_advance_ = + Shape::kContiguous * stride_[0] * sizeof_bits::value / 8; + } + + inc_next_ = inc_advance_ - + LongIndex(ThreadMap::Iterations::kContiguous - 1) * inc_contiguous_ - + LongIndex(ThreadMap::Iterations::kStrided - 1) * inc_strided_; + }; + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + // + // Data members + // + + /// Parameters object with precomputed internal state + Params params_; + + /// Internal pointer to first access of tile + BytePointer pointer_; + + UnderlyingPredicates the_predicates; + Mask residual_tile_mask; + + private: + /// Computes predicates based on internally tracked per-thread offset. 
+ CUTLASS_DEVICE + void compute_predicates_( + /// Extent of the matrix window + TensorCoord extent, + /// optionally, simplify predicate calculation during 'steady state' phase + bool is_steady_state = false) { + the_predicates.compute_predicates_(extent, is_steady_state); + } + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : params_(params), + pointer_(reinterpret_cast( + const_cast(pointer))), + the_predicates(extent) { + the_predicates.set_predicates(thread_id, threadblock_offset); + + // update internal pointers + Layout layout(params_.stride_); + add_pointer_offset(layout(the_predicates.thread_offset_)); + } + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + the_predicates.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool is_residual_tile) { + if (is_residual_tile) { + the_predicates.set_mask(residual_tile_mask); + } + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += sizeof_bits::value * pointer_offset / 8; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + if (kAdvanceRank) { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[1]); + pointer_ += Shape::kContiguous * tile_offset[0]; + } else { + pointer_ += params_.inc_advance_ * LongIndex(tile_offset[0]); + pointer_ += Shape::kStrided * tile_offset[1]; + } + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(pointer_) + + the_predicates.iteration_vector_; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + the_predicates.operator++(); + ++the_predicates.iteration_vector_; + if (the_predicates.iteration_vector_ < kAccessesPerVector) { + return *this; + } + + the_predicates.iteration_vector_ = 0; + ++the_predicates.iteration_contiguous_; + + if (the_predicates.iteration_contiguous_ < + ThreadMap::Iterations::kContiguous) { + pointer_ += params_.inc_contiguous_; + return *this; + } + + // Enter here only if (iteration_contiguous_ == + // ThreadMap::Iteration::kContiguous) + the_predicates.iteration_contiguous_ = 0; + ++the_predicates.iteration_strided_; + + if (the_predicates.iteration_strided_ < ThreadMap::Iterations::kStrided) { + pointer_ += params_.inc_next_strided_; + return *this; + } + + // Enter here only if (iteration_stride_ == ThreadMap::Iteration::kStrided) + // which means we enter the next tile. + the_predicates.iteration_strided_ = 0; + + // advance to next tile + pointer_ += params_.inc_next_; + + // now return to start tile - if the iterator is subsequently advanced, this + // subtraction as well as the subsequent integer addition are both elided by + // the compiler. + pointer_ -= params_.inc_advance_; + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + the_predicates.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + the_predicates.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + the_predicates.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + the_predicates.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return the_predicates.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank 2 +/// column-major data. 
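+/// The column-major flavor forwards its two strides to the AffineRankN<2>
+/// iterator in declaration order (stride(0), stride(1)); the row-major
+/// specialization further below swaps them.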
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::AffineRank2ColumnMajor, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))){}; + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row(), + threadblock_offset.column())) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal 
iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset( + make_Coord(tile_offset.row(), tile_offset.column())); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for affine rank-2 +/// row-major data. 
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::AffineRank2RowMajor, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + /// Default ctor + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))){}; + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + ///< Precomputed parameters object + Params const& params, + ///< Pointer to start of tensor + Pointer pointer, + ///< Extent of tensor + TensorCoord extent, + ///< ID of each participating thread + int thread_id, + ///< Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal 
iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset( + make_Coord(tile_offset.column(), tile_offset.row())); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for column-major +/// interleaved data. It is mapped to the congruous layout. 
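+/// Concretely, a ColumnMajorInterleaved<k> tile of Shape::kRow x Shape::kColumn
+/// is presented to the pitch-linear iterator as
+/// PitchLinearShape<Shape::kRow * k, Shape::kColumn / k>, and extents and
+/// offsets are scaled the same way (rows multiplied by k, columns divided by
+/// k). For instance, with k == 32 a 128 x 64 tile is traversed as a 4096 x 2
+/// pitch-linear region.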
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// + +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + int InterleavedK> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::ColumnMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::ColumnMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape< + Shape::kRow * kInterleavedK, + Shape::kColumn / kInterleavedK>, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 0 : 1), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord( + extent.row() * kInterleavedK, + extent.column() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row() * kInterleavedK, + threadblock_offset.column() / kInterleavedK)) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent 
of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.row(), tile_offset.column()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileAccessIteratorResidualLast for row-major +/// interleaved data. +// It is mapped to the congruous layout. 
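+//
+// Both interleaved specializations simply re-express the logical matrix
+// extent in pitch-linear coordinates before delegating to the pitch-linear
+// iterator. As an illustrative example (the numbers below are only for
+// illustration), with kInterleavedK = 32 and an extent of
+// (row = 128, column = 64):
+//
+//   ColumnMajorInterleaved<32> -> PitchLinearCoord(128 * 32, 64 / 32) = (4096, 2)
+//   RowMajorInterleaved<32>    -> PitchLinearCoord(64 * 32, 128 / 32) = (2048, 4)
+//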
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + typename AccessType_, + int InterleavedK> +class PredicatedTileAccessIteratorResidualLast< + Shape_, + Element_, + layout::RowMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessType_, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::RowMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileAccessIteratorResidualLast< + layout::PitchLinearShape< + Shape::kColumn * kInterleavedK, + Shape::kRow / kInterleavedK>, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 1 : 0), + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = UnderlyingIterator::kAccessesPerVector; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileAccessIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord( + extent.column() * kInterleavedK, + extent.row() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column() * kInterleavedK, + threadblock_offset.row() / kInterleavedK)) {} + + /// Construct a PredicatedTileAccessIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of 
tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileAccessIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(int index) { + iterator_.set_iteration_index(index); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + void add_tile_offset(TensorCoord const& tile_offset) { + iterator_.add_tile_offset({tile_offset.column(), tile_offset.row()}); + } + + /// Returns a pointer + CUTLASS_HOST_DEVICE + AccessType* get() const { + return reinterpret_cast(iterator_.get()); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileAccessIteratorResidualLast operator++(int) { + PredicatedTileAccessIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Returns whether access is valid or not + CUTLASS_HOST_DEVICE + bool valid() { + return iterator_.valid(); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/41_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h b/examples/41_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h new file mode 100644 index 0000000000..4db56560fc --- /dev/null +++ b/examples/41_fused_multi_head_attention/iterators/predicated_tile_iterator_residual_last.h @@ -0,0 +1,2119 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of tiles from pitch-linear rank=2 + tensors. + + This iterator uses masks to guard out-of-bounds accesses. The first tile + this iterator visits maybe partial, then the remaining tiles are complete. + So, we only need to compute the predicates twice, once before the first tile + and once for the remaining full tiles which can share the same predicates. + + A precomputed "Params" object minimizes the amount of state that must be + stored in registers, and integer addition is used to advance the pointer + through memory. +*/ + +#pragma once + +#include "cutlass/arch/memory.h" +#include "cutlass/transform/threadblock/predicated_tile_access_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace transform { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// PredicatedTileIteratorResidualLast +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +/// Regular tile iterator using a precomputed control structure to minimize +/// register liveness and integer arithmetic. +/// +/// Layout is assumed to be invariant at the time the precomputed "Params" +/// object is constructed. +/// +/// Base pointer and tensor extents may be specified at the time the iterator is +/// constructed. Subsequently, they are assumed to be immutable. +/// +/// Adding a logical coordinate offset may be performed at the time the iterator +/// is constructed. Subsequent additions to logical coordinate offset may be +/// performed but are relatively expensive. +/// +/// Visitation order is intended to first visit a "residual" tile that may be +/// partially full in both the advance dimension and the steady-state dimension. +/// This is assumed to be the last tile in the iteration sequence. 
Advancing an +/// iterator that has just been constructed moves to the first tile that is full +/// in the advance dimension and recomputes predicates. Subsequent accesses may +/// be performed without updating internal predicates and are efficient in terms +/// of live register state and pointer arithmetic instructions. +/// +/// To be efficient, this assumes the iterator will be dereferenced and advanced +/// at least once outside any looping structure to minimize integer arithmetic. +/// +/// Acceses out of bounds are safe so long as `clear_mask()` is called prior to +/// dereferencing the iterator. +/// +/// +/// Example: +/// +/// An efficient pipeline structure may be constructed as follows: +/// +// template +// __global__ void kernel( +// typename Iterator::Params params, +// typename Iterator::Element *ptr, +// TensorCoord extent) { +// +// typename Iterator::Fragment fragment; +// +// TensorCoord threadblock_offset(0, 0); +// +// Iterator iter(params, ptr, extent, threadIdx.x, threadblock_offsets); +// +// +// fragment = *iter; // load "residue" tile first +// ++iter; // advance to first "steady state" tile and update +// internal masks +// +// +// #pragma unroll +// for (int i = Remaining - 1; i >= 0; --i) { +// +// f(fragment); +// +// if (!i) { +// iter.clear_mask(); // light-weight operation to clear masks - +// subsequent loads become NO-OPs. +// } +// +// fragment = *iter; // load tile during "steady state" phase +// ++iter; // advance to next tile - lightweight due to +// steady-state masks +// } +// } +// +// void host(TensorView view) { +// +// using Iterator = +// transform::threadblock::PredicatedTileIteratorResidualLast; +// +// typename Iterator::Params params(view.layout()); +// +// kernel(params, view.data()); +// } +/// +/// +template < + typename Shape, + typename Element, + typename Layout, + int AdvanceRank, + typename ThreadMap, + int AccessSize = ThreadMap::kElementsPerAccess, + bool Gather = false> +class PredicatedTileIteratorResidualLast; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. 
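+///
+/// A minimal usage sketch of the residual-last pattern follows. It is
+/// illustrative only: the names (k_iterations, frag, tb_offset, f) are
+/// placeholders, and the exact point at which set_residual_tile() is toggled
+/// depends on the surrounding mainloop. The intent is that predicates are
+/// recomputed only twice: once for the steady-state tiles and once for the
+/// final, possibly partial, residual tile.
+///
+//   Iterator iter(params, ptr, extent, threadIdx.x, tb_offset);
+//
+//   iter.set_residual_tile(k_iterations == 1);  // next tile may be partial
+//   iter.load(frag);                            // load the first tile
+//   ++iter;                                     // update masks, enter steady state
+//
+//   for (; k_iterations > 1; --k_iterations) {
+//     f(frag);
+//     iter.set_residual_tile(k_iterations == 2); // recompute masks only once
+//                                                // more, before the last tile
+//     iter.load(frag);
+//     ++iter;
+//   }
+//   f(frag);
+///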
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + bool Gather> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::PitchLinear, + AdvanceRank, + ThreadMap_, + AccessSize, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::PitchLinear; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + /// Type used for internal memory accesses + using AccessType = AlignedArray< + Element, + AccessSize, + (AccessSize * sizeof_bits::value / 8)>; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = PredicatedTileAccessIteratorResidualLast< + Shape, + Element, + Layout, + kAdvanceRank, + ThreadMap, + AccessType, + Gather>; + + static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename TileAccessIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + using Base = typename TileAccessIterator::Params::Base; + + friend PredicatedTileIteratorResidualLast; + + private: + /// Parameters object + typename TileAccessIterator::Params params_; + + public: + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout) {} + + CUTLASS_HOST_DEVICE + Params() {} + + CUTLASS_HOST_DEVICE + Params(Base const& base) : params_(base) {} + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + /// Gather indices + int const* indices = nullptr) + : address_iterator_( + params.params_, + pointer, + extent, + thread_id, + threadblock_offset, + indices) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< 
ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + if (kAdvanceRank) + address_iterator_.add_tile_offset({0, 1}); + else + address_iterator_.add_tile_offset({1, 0}); + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + address_iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + address_iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + address_iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + address_iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + address_iterator_.get_mask(mask); + } + + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + load_with_byte_offset( + frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + address_iterator_.set_iteration_index(idx); + char const* byte_ptr = + reinterpret_cast(address_iterator_.get()) + + byte_offset; + + AccessType const* access_ptr = + reinterpret_cast(byte_ptr); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, address_iterator_.valid()); + + ++address_iterator_; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_byte_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + store_with_byte_offset( + frag, pointer_offset * sizeof_bits::value / 8); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + address_iterator_.set_iteration_index(0); + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s 
= 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + char* byte_ptr = + reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType* access_ptr = reinterpret_cast(byte_ptr); + + if (address_iterator_.valid()) { + *access_ptr = frag_ptr[idx]; + } + ++address_iterator_; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_byte_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + bool Gather> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::ColumnMajor, + AdvanceRank, + ThreadMap_, + AccessSize, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
0 : 1), + ThreadMap, + AccessSize, + Gather>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row(), + threadblock_offset.column()), + indices) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
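+  ///
+  /// Note that the postfix form copies the iterator before advancing, so the
+  /// prefix form is generally preferable inside the mainloop, e.g.
+  /// (illustrative):
+  ///
+  ///   iter.load(frag);
+  ///   ++iter;   // prefix increment avoids the temporary copy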
+ CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for pitch-linear data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + bool Gather> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::RowMajor, + AdvanceRank, + ThreadMap_, + AccessSize, + Gather> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap, + AccessSize, + Gather>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = nullptr ///< Gather indices + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column(), + threadblock_offset.row()), + indices) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for affine rank-2 data. 
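+///
+/// Unlike the pitch-linear specializations above, an affine rank-2 layout
+/// carries two independent strides, so Params is constructed from the layout
+/// directly rather than from a single pitch. The AffineRank2ColumnMajor and
+/// AffineRank2RowMajor specializations further below only reorder those
+/// strides when building the underlying AffineRankN<2> iterator, roughly:
+///
+//   AffineRank2ColumnMajor -> AffineRankN<2>(layout.stride(0), layout.stride(1))
+//   AffineRank2RowMajor    -> AffineRankN<2>(layout.stride(1), layout.stride(0))
+///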
+/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::AffineRankN<2>, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRankN<2>; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + /// Type used for internal memory accesses + using AccessType = AlignedArray< + Element, + AccessSize, + (AccessSize * sizeof_bits::value / 8)>; + + /// Underlying iterator to compute the addresses + using TileAccessIterator = PredicatedTileAccessIteratorResidualLast< + Shape, + Element, + Layout, + kAdvanceRank, + ThreadMap, + AccessType>; + + static int const kAccessesPerVector = TileAccessIterator::kAccessesPerVector; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename TileAccessIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + public: + friend PredicatedTileIteratorResidualLast; + + private: + /// Parameters object + typename TileAccessIterator::Params params_; + + public: + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) : params_(layout) {} + + CUTLASS_HOST_DEVICE + Params() {} + }; + + private: + /// Internal pointer type permits fast address arithmetic + using BytePointer = char*; + + private: + // + // Data members + // + + /// Data member to the tile access iterator + TileAccessIterator address_iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : address_iterator_( + params.params_, + pointer, + extent, + thread_id, + threadblock_offset) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + 
params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + address_iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + if (kAdvanceRank) + address_iterator_.add_tile_offset(make_Coord(0, 1)); + else + address_iterator_.add_tile_offset(make_Coord(1, 0)); + + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + address_iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + address_iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + address_iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + address_iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + address_iterator_.get_mask(mask); + } + + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + load_with_byte_offset( + frag, pointer_offset * sizeof_bits::value / 8); + } + + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + AccessType* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + address_iterator_.set_iteration_index(idx); + char const* byte_ptr = + reinterpret_cast(address_iterator_.get()) + + byte_offset; + + AccessType const* access_ptr = + reinterpret_cast(byte_ptr); + + cutlass::arch::global_load( + frag_ptr[idx], access_ptr, address_iterator_.valid()); + + ++address_iterator_; + } + } + } + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_byte_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + store_with_byte_offset( + frag, pointer_offset * sizeof_bits::value / 8); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + address_iterator_.set_iteration_index(0); + AccessType const* frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + 
CUTLASS_PRAGMA_UNROLL + for (int c = 0; c < ThreadMap::Iterations::kContiguous; ++c) { + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < kAccessesPerVector; ++v) { + int idx = v + + kAccessesPerVector * (c + s * ThreadMap::Iterations::kContiguous); + + char* byte_ptr = + reinterpret_cast(address_iterator_.get()) + byte_offset; + AccessType* access_ptr = reinterpret_cast(byte_ptr); + + if (address_iterator_.valid()) { + *access_ptr = frag_ptr[idx]; + } + ++address_iterator_; + } + } + } + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_byte_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for affine rank 2 +/// column-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::AffineRank2ColumnMajor, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2ColumnMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 
0 : 1), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(0), layout.stride(1))) {} + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.row(), extent.column()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row(), + threadblock_offset.column())) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for affine rank 2 +/// row-major data. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::AffineRank2RowMajor, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + using Layout = layout::AffineRank2RowMajor; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + // Map to the underlying AffineRankN<2> layout + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape, + Element, + layout::AffineRankN<2>, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given an AffineRankN<2> tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::AffineRankN<2>(layout.stride(1), layout.stride(0))) {} + }; + + private: + // + // Data members + // + + /// Underlying AffineRankN<2> tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id, ///< ID of each participating thread + TensorCoord const& threadblock_offset, ///< Initial offset of threadblock + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord(extent.column(), extent.row()), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column(), + threadblock_offset.row())) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_byte_offset(Fragment& frag, LongIndex byte_offset) { + iterator_.load_with_byte_offset(frag, byte_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_byte_offset(Fragment const& frag, LongIndex byte_offset) { + iterator_.store_with_byte_offset(frag, byte_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for interleaved data. +/// It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// + +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + int InterleavedK> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::ColumnMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::ColumnMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape< + Shape::kRow * kInterleavedK, + Shape::kColumn / kInterleavedK>, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
0 : 1), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord( + extent.row() * kInterleavedK, + extent.column() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.row() * kInterleavedK, + threadblock_offset.column() / kInterleavedK)) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
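+  ///
+  /// Advancing, like construction, delegates to the underlying pitch-linear
+  /// iterator; the interleave factor is folded into the contiguous extent.
+  /// For instance (illustrative numbers), with kInterleavedK = 32 an extent
+  /// of (rows = 128, columns = 64) is presented to the underlying iterator
+  /// as the pitch-linear extent (128 * 32, 64 / 32) = (4096, 2), and the
+  /// threadblock offset is remapped the same way.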
+ CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Specialization of PredicatedTileIteratorResidualLast for interleaved-32 +/// data. It is mapped to the congruous layout. +/// +/// Satisfies: ForwardTileIteratorConcept | +/// ReadableContiguousTileIteratorConcept | +/// WriteableContiguousTileIteratorConcept | +/// MaskedTileIteratorConcept +/// +template < + typename Shape_, + typename Element_, + int AdvanceRank, + typename ThreadMap_, + int AccessSize, + int InterleavedK> +class PredicatedTileIteratorResidualLast< + Shape_, + Element_, + layout::RowMajorInterleaved, + AdvanceRank, + ThreadMap_, + AccessSize, + false> { + public: + static_assert( + AdvanceRank == 0 || AdvanceRank == 1, + "Specialization for pitch-linear iterator may along advance along the " + "contiguous(rank=0) or strided(rank=1) dimension."); + + using Shape = Shape_; + using Element = Element_; + static int const kInterleavedK = InterleavedK; + using Layout = layout::RowMajorInterleaved; + static int const kAdvanceRank = AdvanceRank; + using ThreadMap = ThreadMap_; + + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + + using TensorRef = TensorRef; + using TensorView = TensorView; + using TensorCoord = typename Layout::TensorCoord; + + using Pointer = Element*; + using NonConstPointer = typename platform::remove_const::type*; + + using UnderlyingIterator = PredicatedTileIteratorResidualLast< + layout::PitchLinearShape< + Shape::kColumn * kInterleavedK, + Shape::kRow / kInterleavedK>, + Element, + layout::PitchLinear, + (kAdvanceRank == 0 ? 
1 : 0), + ThreadMap, + AccessSize>; + + using AccessType = typename UnderlyingIterator::AccessType; + + /// Fragment object to be loaded or stored + using Fragment = cutlass::Array< + Element, + ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess>; + + /// Predicate vector stores mask to guard accesses + using Mask = typename UnderlyingIterator::Mask; + + /// Parameters object is precomputed state and is host-constructible + class Params { + private: + friend PredicatedTileIteratorResidualLast; + + /// Parameters object + typename UnderlyingIterator::Params params_; + + public: + CUTLASS_HOST_DEVICE + Params() {} + + /// Construct the Params object given a pitch-linear tensor's layout + CUTLASS_HOST_DEVICE + Params(Layout const& layout) + : params_(layout::PitchLinear(layout.stride(0))) {} + + CUTLASS_HOST_DEVICE + Params(typename UnderlyingIterator::Params::Base const& base) + : params_(base) {} + }; + + private: + // + // Data members + // + + /// Underlying pitch-linear tile iterator + UnderlyingIterator iterator_; + + public: + /// Constructs a TileIterator from its precomputed state, threadblock offset, + /// and thread ID + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + /// Precomputed parameters object + Params const& params, + /// Pointer to start of tensor + Pointer pointer, + /// Extent of tensor + TensorCoord extent, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const& threadblock_offset, + int const* indices = + nullptr ///< gather/scatter indices, note no support for + ///< gather/scatter at this specialization + ) + : iterator_( + params.params_, + pointer, + layout::PitchLinearCoord( + extent.column() * kInterleavedK, + extent.row() / kInterleavedK), + thread_id, + layout::PitchLinearCoord( + threadblock_offset.column() * kInterleavedK, + threadblock_offset.row() / kInterleavedK)) {} + + /// Construct a PredicatedTileIteratorResidualLast with zero threadblock + /// offset + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast( + Params const& params, ///< Precomputed parameters object + Pointer pointer, ///< Pointer to start of tensor + TensorCoord extent, ///< Extent of tensor + int thread_id ///< ID of each participating thread + ) + : PredicatedTileIteratorResidualLast( + params, + pointer, + extent, + thread_id, + make_Coord(0, 0)) {} + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + iterator_.add_pointer_offset(pointer_offset); + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. + CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast& operator++() { + ++iterator_; + return *this; + } + + /// Advances to the next tile in memory. + /// + /// The first time this method is called, predicates are updated, and the + /// iterator's internal pointer is reverted to the first "steady state" tile. + /// Subsequent calls are lightweight and must only update the internal + /// pointer. 
+ CUTLASS_HOST_DEVICE + PredicatedTileIteratorResidualLast operator++(int) { + PredicatedTileIteratorResidualLast self(*this); + operator++(); + return self; + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void clear_mask(bool enable = true) { + iterator_.clear_mask(enable); + } + + CUTLASS_HOST_DEVICE + void set_residual_tile(bool enable) { + iterator_.set_residual_tile(enable); + } + + /// Clears the predicate set efficiently + CUTLASS_HOST_DEVICE + void enable_mask() { + iterator_.enable_mask(); + } + + /// Sets the predicate mask, overriding value stored in predicate iterator + CUTLASS_HOST_DEVICE + void set_mask(Mask const& mask) { + iterator_.set_mask(mask); + } + + /// Gets the mask + CUTLASS_HOST_DEVICE + void get_mask(Mask& mask) { + iterator_.get_mask(mask); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load_with_pointer_offset(Fragment& frag, Index pointer_offset) { + iterator_.load_with_pointer_offset(frag, pointer_offset); + } + + /// Loads a fragment from memory + CUTLASS_DEVICE + void load(Fragment& frag) { + load_with_pointer_offset(frag, 0); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store_with_pointer_offset(Fragment const& frag, Index pointer_offset) { + iterator_.store_with_pointer_offset(frag, pointer_offset); + } + + /// Store a fragment to memory + CUTLASS_DEVICE + void store(Fragment const& frag) { + store_with_pointer_offset(frag, 0); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace transform +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/41_fused_multi_head_attention/iterators/transpose_warp_iterator.h b/examples/41_fused_multi_head_attention/iterators/transpose_warp_iterator.h new file mode 100644 index 0000000000..f0f8ea6086 --- /dev/null +++ b/examples/41_fused_multi_head_attention/iterators/transpose_warp_iterator.h @@ -0,0 +1,55 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "warp_iterator_from_smem.h" + +template +struct TransposeWarpIterator { + using Iterator = char; + static bool constexpr kSupportsTranspose = false; +}; + +template < + /// Operand identity + cutlass::gemm::Operand Operand, + /// Data type of A elements + typename Element, + typename InstructionShape, + bool kTranspose> +struct TransposeWarpIterator< + cutlass::gemm::warp:: + WarpIteratorFromSmem> { + using Iterator = cutlass::gemm::warp:: + WarpIteratorFromSmem; + static bool constexpr kSupportsTranspose = true; +}; diff --git a/examples/41_fused_multi_head_attention/iterators/warp_iterator_from_smem.h b/examples/41_fused_multi_head_attention/iterators/warp_iterator_from_smem.h new file mode 100644 index 0000000000..d19b1907d5 --- /dev/null +++ b/examples/41_fused_multi_head_attention/iterators/warp_iterator_from_smem.h @@ -0,0 +1,283 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Inspired from + "cutlass/gemm/warp/mma_tensor_op_tile_access_iterator.h" Loads tiles of GEMM + operands from a RowMajor shared-memory layout into registers to use by A100 + TensorCores. + + The difference with "mma_tensor_op_tile_access_iterator.h" is that: + (1) We use "ldmatrix" to load tiles, rather than manual loads (slightly + faster) (2) We support to transpose the operand (eg read `A.transpose()` when + the shared memory holds `A`) + + This is only implemented for the specific shapes. +*/ +#pragma once + +#include + +//////////////////////////////////////////////////////////////////////////////// +namespace cutlass { +namespace gemm { +namespace warp { + +template < + /// Operand identity + Operand Operand_, + /// Data type of A elements + typename Element_, + typename InstructionShape_, + bool kTranspose = false> +class WarpIteratorFromSmem { + public: + /// Shape of tile to load (concept: MatrixShape) + using Shape = cutlass::MatrixShape<32, 32>; + + /// Operand tag + static Operand const kOperand = Operand_; + static_assert( + kOperand == Operand::kA, + "No support for OperandB at the moment"); + + /// Basic check + static_assert( + kOperand == Operand::kA || kOperand == Operand::kB, + "WarpIteratorFromSmem may only be instantiated for A or B operands to warp-level Mma."); + + /// Element type + using Element = Element_; + static_assert(sizeof_bits::value == 16, "Only supported for half"); + + /// Layout of source tile + using Layout = cutlass::layout::RowMajor; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + static_assert(InstructionShape::kRow == 16, "Only supports 16x8x8 / 16x8x16"); + static_assert( + InstructionShape::kColumn == 8 || InstructionShape::kColumn == 16, + "Only supports 16x8x8 / 16x8x16"); + + /// Delta between *MMA operations (in units of *MMA operations, concept: + /// MatrixShape) + static int const kOpDelta = 1; + + /// Number of participating threads + static int const kThreads = 32; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Number of elements accessed per Shared Memory load + static int const kElementsPerAccess = + (sizeof_bits::value >= 32 ? 1 + : 32 / sizeof_bits::value); + + using InstructionCount = MatrixShape< + Shape::kRow / InstructionShape::kRow, + Shape::kColumn / InstructionShape::kColumn>; + + static int const kIterations = (kOperand == Operand::kA) + ? InstructionCount::kColumn + : InstructionCount::kRow; + + public: + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + using Fragment = Array< + Element, + (kOperand == Operand::kA) + ? (Shape::kRow* InstructionShape::kColumn / kThreads) + : (Shape::kColumn* InstructionShape::kRow / kThreads)>; + + /// Memory access type + // using AccessType = AlignedArray; + using AccessType = Array; + + static int constexpr kWarpShapeDivisibleInner = + (kOperand == Operand::kA ? 
InstructionShape::kColumn + : InstructionShape::kRow); + static int constexpr kAccessesInner = + (kWarpShapeDivisibleInner / kElementsPerAccess) / 4; + // Number of 32bits tiles to load per `ldmatrix` + static int const kTilesPerInstruction = InstructionShape::kRow / 8; + static_assert(kTilesPerInstruction == 2, "Only supports 16x8x16 and 16x8x8"); + + private: + /// Underlying tensor reference + TensorRef ref_; + + /// Origin + MatrixCoord origin_; + + /// Iterations in a tile + int iterations_; + + public: + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem(TensorRef const& ref, int lane_id) + : WarpIteratorFromSmem(ref, {Shape::kRow, Shape::kColumn}, lane_id) {} + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem(TensorRef const& ref, TensorCoord extent, int lane_id) + : ref_(ref), iterations_(0) { + // See also: + // https://docs.nvidia.com/cuda/archive/11.7.1/parallel-thread-execution/index.html#warp-level-matrix-fragment-mma-1688 + // 16x8x8: kAccessesInner = 1 (1 ldmatrix.x4) + // 16x8x16: kAccessesInner = 2 (2 ldmatrix.x4) + int ldsm_vec_num = (lane_id >> 3); + if (kOperand == Operand::kA) { + origin_ = MatrixCoord(lane_id % 8, 0); + static_assert( + InstructionCount::kRow * kTilesPerInstruction == 4, + "can't use ldmatrix.x4"); + int access_m_idx = ldsm_vec_num % kTilesPerInstruction; + int inner_idx = (ldsm_vec_num / kTilesPerInstruction) % kAccessesInner; + int inst_m_idx = ldsm_vec_num / (kTilesPerInstruction * kAccessesInner); + MatrixCoord offset( + access_m_idx * 8 + inst_m_idx * InstructionShape::kRow, + inner_idx * 4 * kElementsPerAccess); + if (kTranspose) { + offset = MatrixCoord(offset.column(), offset.row()); + } + origin_ += offset; + } else { + // Note: This is not tested or used + origin_ = MatrixCoord(0, lane_id % 8); + static_assert(InstructionCount::kColumn * kAccessesInner == 4, ""); + CUTLASS_PRAGMA_UNROLL + for (int inst_n_idx = 0; inst_n_idx < InstructionCount::kColumn; + ++inst_n_idx) { + CUTLASS_PRAGMA_UNROLL + for (int inner_idx = 0; inner_idx < kAccessesInner; ++inner_idx) { + int access_idx = inner_idx + kAccessesInner * inst_n_idx; + + MatrixCoord offset( + inner_idx * 4 * kElementsPerAccess, inst_n_idx * 8); + + if (access_idx == ldsm_vec_num) { + if (kTranspose) { + offset = MatrixCoord(offset.column(), offset.row()); + } + origin_ += offset; + } + } + } + } + + ref_.add_coord_offset(origin_); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole + /// tiles + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem& add_tile_offset(TensorCoord const& tile_offset) { + TensorCoord coord_offset( + tile_offset.row() * Shape::kRow, tile_offset.column() * Shape::kColumn); + if (kTranspose) { + coord_offset = TensorCoord{coord_offset.column(), coord_offset.row()}; + } + origin_ += coord_offset; + + ref_.add_coord_offset(coord_offset); + + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_DEVICE + void advance() { + if (kOperand == Operand::kA) { + add_tile_offset({0, 1}); + } else { + add_tile_offset({1, 0}); + } + + iterations_ = 0; + } + + /// increase iterations in a tile + CUTLASS_HOST_DEVICE + WarpIteratorFromSmem& operator++() { + iterations_++; + + if (iterations_ >= kIterations) + advance(); + + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. 
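+  ///
+  /// Worked example for fp16 operand A: with a 16x8x16 instruction,
+  /// kElementsPerAccess = 32 / 16 = 2, kAccessesInner = (16 / 2) / 4 = 2 and
+  /// kTilesPerInstruction = 16 / 8 = 2, so each load() issues
+  /// (InstructionCount::kRow * kTilesPerInstruction * kAccessesInner) / 4
+  /// = (2 * 2 * 2) / 4 = 2 ldmatrix.x4 accesses; a 16x8x8 instruction gives
+  /// kAccessesInner = 1 and a single ldmatrix.x4, matching the note in the
+  /// constructor above.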
+ CUTLASS_DEVICE + void load(Fragment& frag) const { + AccessType* access_ptr = reinterpret_cast(&frag); + using LoadLayout = typename platform:: + conditional::type; + + CUTLASS_PRAGMA_UNROLL + for (int access_m_idx = 0; access_m_idx < + (InstructionCount::kRow * kTilesPerInstruction * kAccessesInner) / 4; + ++access_m_idx) { + MatrixCoord offset; + if (kOperand == Operand::kA) { + offset = MatrixCoord( + access_m_idx * 16, iterations_ * InstructionShape::kColumn); + } else { + offset = MatrixCoord(iterations_ * InstructionShape::kRow, 0); + } + if (kTranspose) { + offset = MatrixCoord(offset.column(), offset.row()); + } + cutlass::arch::ldsm( + access_ptr[access_m_idx], ref_.data() + ref_.offset(offset)); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/41_fused_multi_head_attention/kernel_backward.h b/examples/41_fused_multi_head_attention/kernel_backward.h new file mode 100644 index 0000000000..6fd94a6c58 --- /dev/null +++ b/examples/41_fused_multi_head_attention/kernel_backward.h @@ -0,0 +1,2554 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include +#include + +#include +#include + +#ifdef HAS_PYTORCH +#include +#include +#include +#include +#endif + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/fast_math.h" +#include "cutlass/functional.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" + +#include "debug_utils.h" +#include "gemm_kernel_utils.h" + +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/threadblock/epilogue_smem_accumulator.h" +#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/integer_subbyte.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/platform/platform.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/vector_iterator.h" +#include "epilogue/epilogue_pipelined.h" +#include "iterators/epilogue_predicated_tile_iterator.h" + +#include "gemm/custom_mma.h" +#include "gemm/find_default_mma.h" +#include "gemm/mma_accum_lambda_iterator.h" +#include "gemm/mma_from_smem.h" +#include "transform/tile_smem_loader.h" + +using namespace gemm_kernel_utils; + +namespace { + +template +struct GmemTile { + /* + Helper functions to efficient store/load RF to gmem + + GEMM accumulators have a particular format on A100, and + it takes some compute/shared-memory to rearrange them to + a RowMajor or ColumnMajor format in global memory through + an Epilogue. The same complexity goes for loading into RF. + + This class loads/stores RF as they are, and can be used for + efficient accumulation across gemms for instance: + + ``` + GmemTile tile; + for (int i = 0; i < N; ++i) { + // ... + + Fragment accum; + if (i == 0) { + accum.clear(); + } else { + tile.load(accum); + } + mma(accum, ...); + if (i < N-1) { + // Store for next GEMM + tile.store(accum); + } else { + // Store in tensor (eg RowMajor) + epilogue(accum); + } + + // ... 
+ } + ``` + */ + + // 128bits per thread + using AccessType = cutlass::Array; + static constexpr int32_t kBytes = sizeof(AccessType); + static constexpr int32_t kStride = kNumThreads * AccessType::kElements; + static constexpr int32_t kNumIters = + FragmentType::kElements / AccessType::kElements; + static constexpr int32_t kElementsStored = + kNumThreads * FragmentType::kElements; + static_assert( + FragmentType::kElements % AccessType::kElements == 0, + "fragment not aligned on 128 bits"); + + float* ptr; + + CUTLASS_DEVICE void load(FragmentType& fragment, int thread_id) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNumIters; ++i) { + AccessType* __restrict__ gmem_ptr = reinterpret_cast( + ptr + thread_id * AccessType::kElements + i * kStride); + AccessType sub_fragment; + cutlass::arch::global_load( + sub_fragment, gmem_ptr, true); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < AccessType::kElements; ++j) { + fragment[i * AccessType::kElements + j] = sub_fragment[j]; + } + } + } + + CUTLASS_DEVICE void store(FragmentType const& fragment, int thread_id) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNumIters; ++i) { + AccessType* __restrict__ gmem_ptr = reinterpret_cast( + ptr + thread_id * AccessType::kElements + i * kStride); + AccessType sub_fragment; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < AccessType::kElements; ++j) { + sub_fragment[j] = fragment[i * AccessType::kElements + j]; + } + cutlass::arch::global_store( + sub_fragment, gmem_ptr, true); + } + } + + CUTLASS_DEVICE void storeAtomicAdd( + FragmentType const& fragment, + int thread_id) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kNumIters; ++i) { + float* gmem_ptr = ptr + thread_id * AccessType::kElements + i * kStride; + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < AccessType::kElements; ++j) { + float val = fragment[i * AccessType::kElements + j]; + float* ptr = gmem_ptr + j; + atomicAdd(ptr, val); + } + } + } +}; + +struct AtomicLock { + CUTLASS_DEVICE static void acquire( + int32_t* lock, + int set_val, + int thread_id) { + if (thread_id == 0) { + while (atomicCAS(lock, 0 /*cmp*/, set_val /*setval*/) != set_val) { +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + __nanosleep(40); +#endif + } + } + __syncthreads(); + } + CUTLASS_DEVICE static void release(int32_t* lock, int thread_id) { + if (thread_id == 0) { + int status = 0; +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 700 + asm volatile("st.global.release.gpu.b32 [%0], %1;\n" + : + : "l"(lock), "r"(status)); +#else + asm volatile("st.global.cg.b32 [%0], %1;\n" : : "l"(lock), "r"(status)); +#endif + } + } +}; + +template +constexpr int getWarpsPerSmBw() { + bool is_half = !cutlass::platform::is_same::value; + if (Arch::kMinComputeCapability >= 80) { + return is_half ? 
12 : 8; + } + return 8; +} +} // namespace + +template < + // which arch we target (eg `cutlass::arch::Sm80`) + typename ArchTag_, + // input/output type + typename scalar_t_, + // run optimized kernel because memory accesses will be aligned + bool kIsAligned_, + // use dropout if enabled + bool kApplyDropout_, + // when doing a GEMM, preload the next one (uses more shmem) + bool kPreload_, + // block dimensions + int kBlockSizeI_, + int kBlockSizeJ_, + // upperbound on `max(value.shape[-1], query.shape[-1])` + int kMaxK_ = (int)cutlass::platform::numeric_limits::max(), + // assumes that `cu_seqlen` is None, and + // (1) `num_queries % kBlockSizeI == 0` + // (2) `num_keys % kBlockSizeJ == 0` + bool kKeysQueriesAlignedToBlockSize_ = false, + // Allows to parallelize across keys + bool kEnableSplitKeys_ = true> +struct AttentionBackwardKernel { + enum CustomMaskType { + NoCustomMask = 0, + CausalFromTopLeft = 1, + CausalFromBottomRight = 2, + NumCustomMaskTypes, + }; + using scalar_t = scalar_t_; + using output_t = scalar_t; + using output_accum_t = float; + using lse_scalar_t = float; + using accum_t = float; + using ArchTag = ArchTag_; + static constexpr bool kIsAligned = kIsAligned_; + static constexpr bool kApplyDropout = kApplyDropout_; + static constexpr bool kPreload = kPreload_; + static constexpr int kBlockSizeI = kBlockSizeI_; + static constexpr int kBlockSizeJ = kBlockSizeJ_; + static constexpr int kMaxK = kMaxK_; + static constexpr bool kKeysQueriesAlignedToBlockSize = + kKeysQueriesAlignedToBlockSize_; + + static constexpr int64_t kWarpSize = 32; + + // If this is true, we store and accumulate dK/dV in RF + // rather than going back to gmem everytime + static constexpr bool kIsHalf = cutlass::sizeof_bits::value <= 16; + static constexpr bool kOutputInRF = kIsHalf && kMaxK <= kBlockSizeI; + static_assert( + !kPreload || + (kIsHalf && ArchTag::kMinComputeCapability >= 80 && kOutputInRF), + "preload MMA not supported"); + static constexpr bool kPrologueQK = kPreload; + static constexpr bool kPrologueGV = kPreload; + static constexpr bool kPrologueDOV = kPreload; + static constexpr bool kPrologueGQ = kPreload; + static constexpr bool kPrologueGK = kPreload; + + static constexpr int64_t kNumWarpsPerBlock = + (kBlockSizeI * kBlockSizeJ) / (32 * 32); + + // Compute delta for the f16 kernels + // TODO: Figure out why it's slower on the f32 kernels + // (something due to RF pressure?) + // TODO: Remove condition on `kOutputInRF` - this is needed to work + // around a compiler bug on V100, not exactly sure why but I spent + // too much time on this already. 
Reproducible with + // (B, Mq, Mkv, K) = (1, 1, 1, 136) for instance + static constexpr bool kKernelComputesDelta = + kIsHalf && (kOutputInRF || ArchTag::kMinComputeCapability != 70); + + // Launch bounds + static constexpr int64_t kNumThreads = kWarpSize * kNumWarpsPerBlock; + static constexpr int64_t kMinBlocksPerSm = + getWarpsPerSmBw() / kNumWarpsPerBlock; + + using GemmType = DefaultGemmType; + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + typename GemmType::OpClass, + ArchTag, + scalar_t, + scalar_t, + scalar_t, // ElementC + accum_t // ElementAccumulator + >; + static constexpr auto kOptimalAlignement = cutlass::platform::max( + DefaultConfig::kAlignmentA, + DefaultConfig::kAlignmentB); + static constexpr auto kMinimumAlignment = GemmType::kMinimumAlignment; + + struct MatmulQK { + /* + attn_T = k_j @ q_i.transpose(-2, -1) # matmul + attn_T = (attn_T - logsumexp[i_start:i_end].unsqueeze(1).transpose(-2, + -1)).exp() # epilogue + + with attn_T.shape = (kBlockSizeJ, kBlockSizeI) + */ + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using DefaultMma = typename cutlass::gemm::threadblock::DefaultMma< + scalar_t, // ElementA + cutlass::layout::RowMajor, // LayoutA + kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment, + scalar_t, // ElementB + cutlass::layout::ColumnMajor, // LayoutB + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, + accum_t, // ElementC + cutlass::layout::RowMajor, // LayoutC + typename GemmType::OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + DefaultConfig::kStages, + typename GemmType::Operator, + false, // AccumulatorsInRowMajor = false, + cutlass::gemm::SharedMemoryClearOption::kNone>; + using MmaCore = typename DefaultMma::MmaCore; + using Mma = + typename MakeCustomMma::Mma; + + // used for efficient load of bias tile (Bij) from global memory to shared + // memory + using BiasLoader = TileSmemLoader< + scalar_t, + // Bij is applied to transposed attn matrix tile (Pij.T). Bij is loaded + // row-major but needs to have transposed shape so we get the same + // elements. + cutlass::MatrixShape, + MmaCore::kThreads, + // input restriction: kv_len has to be a multiple of this value + 128 / cutlass::sizeof_bits::value>; + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< + typename Mma::Operator::IteratorC, + typename Mma::Operator, + scalar_t, + WarpShape, + ThreadblockShape>; + using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator< + typename Mma::Operator::IteratorC, + accum_t, + kWarpSize>::Iterator; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; + }; + + struct MatmulGradV { + /* + grad_v[j_start:j_end] += attn_T @ do_i # matmul + + Dimensions: (kBlockSizeJ * kNumWarpsPerBlock, kBlockSizeI, K) + (we might need to iterate multiple times on K) + */ + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + DefaultConfig::kAlignmentA, + scalar_t, // ElementB, + cutlass::layout::RowMajor, // LayoutB, + kIsAligned ? 
DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, + output_t, + cutlass::layout::RowMajor, // LayoutC, + accum_t, + typename GemmType::OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + typename DefaultConfig::EpilogueOutputOp, + void, // ThreadblockSwizzle - not used + DefaultConfig::kStages, + false, // SplitKSerial + typename GemmType::Operator>; + + // if dropout: + // for computing dVj += (Pij.T * Zij) @ dOi + // Pij_dropped.T = Pij.T * Zij is computed on the fly as fragments of + // Pij.T are loaded in. The reason we do it this way is because Pij.T and + // Zij are reused in later steps, while Pij_dropped.T is only needed in + // this step. computing Pij_dropped.T on the fly allows us to avoid + // keeping all 3 of Pij_dropped.T, Pij.T, and Zij in shared memory at the + // same time. + // if no dropout: + // for computing dVj += Pij.T @ dOi + using WarpIteratorA = typename cutlass::gemm::threadblock:: + DefaultWarpIteratorAFromSharedMemory< + typename DefaultGemm::Mma::Operator::Shape, // WarpShape + typename DefaultGemm::Mma::Operator:: + InstructionShape, // InstructionShape + typename DefaultGemm::Mma::Operator:: + IteratorA, // RegularWarpIterator + typename DefaultGemm::Mma::Policy // Policy + >::WarpIterator; + using DefaultMmaFromSmem = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + MatmulQK::AccumulatorSharedStorage::Shape::kN, + WarpIteratorA, + kApplyDropout>; // kScaleOperandA + + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + + // Epilogue + using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp; + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::MakePrefetchableIterator< + typename DefaultEpilogue::OutputTileIterator>::Iterator; + using AccumTileGmem = GmemTile; + }; + + struct MatmulDOIVJ { + /* + doi_t_vj = do_i @ v_j.transpose(-2, -1) # matmul + tmp = (doi_t_vj - Di.unsqueeze(1)) * attn # inplace / epilogue? + */ + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + + using ElementC = output_t; + using ElementAccum = accum_t; + + // no-op output op - epilogue just stores result to global memory + using BiasGradEpilogueOutputOp = + typename cutlass::epilogue::thread::LinearCombination< + ElementC, + DefaultConfig::EpilogueOutputOp::kCount, + typename DefaultConfig::EpilogueOutputOp::ElementAccumulator, + typename DefaultConfig::EpilogueOutputOp::ElementCompute, + cutlass::epilogue::thread::ScaleType::Nothing>; + + using DefaultGemm = typename cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA + cutlass::layout::RowMajor, // LayoutA + kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment, + scalar_t, // ElementB + cutlass::layout::ColumnMajor, // LayoutB + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, + ElementC, // ElementC + cutlass::layout::RowMajor, // LayoutC + ElementAccum, // ElementAccumulator + typename GemmType::OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + BiasGradEpilogueOutputOp, // EpilogueOutputOp + void, // ThreadblockSwizzle (not used) + // multiple preloads, dropout Zij tile, and 3 stages push us over shared + // memory capacity on A100. 
set a ceiling on number of stages to save + // shared memory if dropout is in use. + kPreload && kApplyDropout && (kBlockSizeI * kBlockSizeJ > 64 * 64) + ? cutlass::const_min(2, DefaultConfig::kStages) + : DefaultConfig::kStages, // Stages + false, // SplitKSerial + typename GemmType::Operator, + cutlass::gemm::SharedMemoryClearOption::kNone>; + using Mma = typename MakeCustomMma::Mma; + using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator< + typename Mma::Operator::IteratorC, + ElementAccum, + kWarpSize>::Iterator; + + // epilogue used to write bias gradient, which is just the output of this + // matmul with some operations applied to the fragment + using BiasGradEpilogue = typename DefaultGemm::Epilogue; + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< + typename DefaultGemm::Mma::Operator::IteratorC, + typename DefaultGemm::Mma::Operator, + scalar_t, + WarpShape, + ThreadblockShape>; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; + }; + + struct MatmulGradQ { + // grad_q <- tmp @ k_j + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + DefaultConfig::kAlignmentA, + scalar_t, // ElementB, + cutlass::layout::RowMajor, // LayoutB, + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, + output_t, + cutlass::layout::RowMajor, // LayoutC, + accum_t, + typename GemmType::OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + typename DefaultConfig::EpilogueOutputOp, + void, // ThreadblockSwizzle - not used + DefaultConfig::kStages, + false, // SplitKSerial + typename GemmType::Operator>; + + using WarpIteratorA = typename cutlass::gemm::threadblock:: + DefaultWarpIteratorAFromSharedMemory< + typename DefaultGemm::Mma::Operator::Shape, + typename DefaultGemm::Mma::Operator::InstructionShape, + typename DefaultGemm::Mma::Operator::IteratorA, + typename DefaultGemm::Mma::Policy>::WarpIterator; + using DefaultMmaFromSmem = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + MatmulDOIVJ::AccumulatorSharedStorage::Shape::kN, + WarpIteratorA, + false>; // kScaleOperandA + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + + // Epilogue + using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp; + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::MakePrefetchableIterator< + typename DefaultEpilogue::OutputTileIterator>::Iterator; + using AccumTileGmem = GmemTile; + }; + struct MatmulGradK { + // grad_k <- tmp.transpose(-2, -1) @ q_i + using ThreadblockShape = + cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + DefaultConfig::kAlignmentA, + scalar_t, // ElementB, + cutlass::layout::RowMajor, // LayoutB, + kIsAligned ? 
DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment, + output_t, + cutlass::layout::RowMajor, // LayoutC, + accum_t, + typename GemmType::OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + typename DefaultConfig::EpilogueOutputOp, + void, // ThreadblockSwizzle - not used + DefaultConfig::kStages, + false, // SplitKSerial + typename GemmType::Operator>; + + using WarpIteratorA = typename cutlass::gemm::threadblock:: + DefaultWarpIteratorAFromSharedMemory< + typename DefaultGemm::Mma::Operator::Shape, + typename DefaultGemm::Mma::Operator::InstructionShape, + typename DefaultGemm::Mma::Operator::IteratorA, + typename DefaultGemm::Mma::Policy>::WarpIterator; + using DefaultMmaFromSmemN = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + MatmulQK::AccumulatorSharedStorage::Shape::kN, // kMaxK + WarpIteratorA, + false>; // kScaleOperandA + using DefaultMmaFromSmemT = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + MatmulDOIVJ::AccumulatorSharedStorage::Shape::kM, // kMaxK + WarpIteratorA, + false, // kScaleOperandA + kPreload>; // kTransposeA + using DefaultMmaFromSmem = typename cutlass::platform::conditional< + DefaultMmaFromSmemT::kIsTransposedA, + DefaultMmaFromSmemT, + DefaultMmaFromSmemN>::type; + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + + // Epilogue + using DefaultOutputOp = typename DefaultConfig::EpilogueOutputOp; + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::MakePrefetchableIterator< + typename DefaultEpilogue::OutputTileIterator>::Iterator; + using AccumTileGmem = GmemTile; + }; + + static constexpr bool kEnableSplitKeys = kEnableSplitKeys_; + + static constexpr bool kNeedsAccumGradQ = kEnableSplitKeys || + !cutlass::platform::is_same::value; + static constexpr bool kNeedsAccumGradK = !kOutputInRF && + !cutlass::platform::is_same::value; + static constexpr bool kNeedsAccumGradV = !kOutputInRF && + !cutlass::platform::is_same::value; + + struct GradQTempStorage { + int32_t lock; + int32_t counter; + int32_t pad[2]; // pad to 128bits + output_accum_t buffer[MatmulGradQ::AccumTileGmem::kElementsStored]; + }; + + struct Params { + // Input tensors + scalar_t* query_ptr = nullptr; // [Mq, nH, K] + scalar_t* key_ptr = nullptr; // [Mk, nH, K] + scalar_t* value_ptr = nullptr; // [Mk, nH, Kv] + scalar_t* bias_ptr = nullptr; + lse_scalar_t* logsumexp_ptr = nullptr; // [nH, Mq] + scalar_t* output_ptr = nullptr; // [Mq, nH, Kv] + scalar_t* grad_output_ptr = nullptr; // [Mq, nH, Kv] + accum_t* delta_ptr = nullptr; // [nH, Mq] + int32_t* cu_seqlens_q_ptr = nullptr; + int32_t* cu_seqlens_k_ptr = nullptr; + + // Output tensors + output_t* grad_query_ptr = nullptr; // [Mq, nH, K] + output_t* grad_key_ptr = nullptr; // [Mk, nH, K] + output_t* grad_value_ptr = nullptr; // [Mk, nH, Kv] + output_t* grad_bias_ptr = nullptr; + + // Accumulators + output_accum_t* workspace = nullptr; // [Mq, Kq] + [Mkv, Kq] + [Mkv, Kv] + output_accum_t* workspace_gv = + nullptr; // (will be calculated by the kernel) + GradQTempStorage* workspace_gq = + nullptr; // (will be calculated by the kernel) + + // Scale + accum_t scale = 1.0f; + + // Dimensions/strides + int32_t head_dim = -1; + int32_t head_dim_value = -1; + int32_t num_queries = -1; + int32_t num_keys = -1; + int32_t num_heads = -1; + uint8_t 
custom_mask_type = NoCustomMask; + + int32_t q_strideM = -1; + int32_t k_strideM = -1; + int32_t v_strideM = -1; + int32_t bias_strideM = 0; + int32_t gO_strideM = -1; + int32_t gB_strideM = -1; + int8_t gQKV_strideM_multiplier = 1; // 3 for packed, 1 otherwise + +#ifdef HAS_PYTORCH + // dropout + at::PhiloxCudaState rng_engine_inputs = {0, 0}; +#endif + // RNG sequence offset based on batch_id and head_id + unsigned long long dropout_batch_head_rng_offset = 0; + float dropout_prob = 0.0f; + + CUTLASS_HOST_DEVICE int32_t o_strideM() const { + return head_dim_value * num_heads; + } + CUTLASS_HOST_DEVICE int32_t gQ_strideM() const { + return gQKV_strideM_multiplier * num_heads * head_dim; + } + CUTLASS_HOST_DEVICE int32_t gK_strideM() const { + return gQKV_strideM_multiplier * num_heads * head_dim; + } + CUTLASS_HOST_DEVICE int32_t gV_strideM() const { + return gQKV_strideM_multiplier * num_heads * head_dim_value; + } + + // Everything below is only used in `advance_to_block` + // and shouldn't use registers + int64_t o_strideH = -1; + int32_t q_strideH = -1; + int32_t k_strideH = -1; + int32_t v_strideH = -1; + int64_t bias_strideH = 0; + int64_t o_strideB = -1; + int64_t q_strideB = -1; + int64_t k_strideB = -1; + int64_t v_strideB = -1; + int64_t bias_strideB = 0; + int64_t lse_strideB = -1; + int64_t lse_strideH = -1; + int64_t delta_strideB = -1; + int64_t delta_strideH = -1; + int32_t num_batches = -1; + int16_t num_splits_key = 1; // We use `gridDim.x` inside kernel + + int64_t gO_strideB = 0; + int64_t gQ_strideB = 0; + int64_t gK_strideB = 0; + int64_t gV_strideB = 0; + int64_t gB_strideB = 0; + int64_t gO_strideH = 0; + int64_t gQ_strideH = 0; + int64_t gK_strideH = 0; + int64_t gV_strideH = 0; + int64_t gB_strideH = 0; + + CUTLASS_DEVICE int16_t num_splits_key_device() const { + return kEnableSplitKeys ? gridDim.x : 1; + } + CUTLASS_DEVICE int16_t split_key_device() const { + return kEnableSplitKeys ? 
blockIdx.x : 0; + } + + CUTLASS_DEVICE bool advance_to_block() { + int64_t batch_id = blockIdx.z; + int32_t head_id = blockIdx.y; + + if (kNeedsAccumGradQ || kNeedsAccumGradK || kNeedsAccumGradV) { + assert(workspace_size() == 0 || workspace != nullptr); + + workspace += (batch_id * num_heads + head_id) * workspace_strideBH(); + workspace = warp_uniform(workspace); + workspace_gv = workspace + workspace_elements_gk(); + workspace_gq = + (GradQTempStorage*)(workspace_gv + workspace_elements_gv()); + if (kEnableSplitKeys) { + workspace_gv += workspace_elements_gv() * split_key_device() / + num_splits_key_device(); + workspace += workspace_elements_gk() * split_key_device() / + num_splits_key_device(); + } + } else { + workspace = nullptr; + } + + // Advance pointers that depend on the total concatenated + // number of queries, as `num_queries` is modified in the block + // below + dropout_batch_head_rng_offset = + batch_id * (num_heads * num_queries * num_keys) + + head_id * (num_queries * num_keys); + logsumexp_ptr += batch_id * lse_strideB + head_id * lse_strideH; + + if (cu_seqlens_q_ptr != nullptr) { + assert(cu_seqlens_k_ptr != nullptr); + cu_seqlens_q_ptr += batch_id; + cu_seqlens_k_ptr += batch_id; + int32_t q_start = cu_seqlens_q_ptr[0]; + int32_t k_start = cu_seqlens_k_ptr[0]; + int64_t q_next_start = cu_seqlens_q_ptr[1]; + int64_t k_next_start = cu_seqlens_k_ptr[1]; + assert(q_next_start - q_start <= num_queries); + assert(k_next_start - k_start <= num_keys); + num_queries = q_next_start - q_start; + num_keys = k_next_start - k_start; + + // Jump manually + batch_id = 0; + + query_ptr += q_start * q_strideM; + key_ptr += k_start * k_strideM; + value_ptr += k_start * v_strideM; + assert(bias_ptr == nullptr); + assert(grad_bias_ptr == nullptr); + output_ptr += q_start * o_strideM(); + grad_output_ptr += q_start * gO_strideM; + delta_ptr += q_start; + + grad_query_ptr += q_start * gQ_strideM(); + grad_key_ptr += k_start * gK_strideM(); + grad_value_ptr += k_start * gV_strideM(); + } + + query_ptr += batch_id * q_strideB + head_id * q_strideH; + key_ptr += batch_id * k_strideB + head_id * k_strideH; + value_ptr += batch_id * v_strideB + head_id * v_strideH; + if (bias_ptr != nullptr) { + bias_ptr += batch_id * bias_strideB + head_id * bias_strideH; + } + output_ptr += batch_id * o_strideB + head_id * o_strideH; + grad_output_ptr += batch_id * gO_strideB + head_id * gO_strideH; + delta_ptr += batch_id * delta_strideB + head_id * delta_strideH; + + grad_query_ptr += batch_id * gQ_strideB + head_id * gQ_strideH; + grad_key_ptr += batch_id * gK_strideB + head_id * gK_strideH; + grad_value_ptr += batch_id * gV_strideB + head_id * gV_strideH; + if (grad_bias_ptr != nullptr) { + grad_bias_ptr += batch_id * gB_strideB + head_id * gB_strideH; + } + + // Some values are modified above + // Signal to the compiler that they are the same in all threads + // and can be stored in warp-uniform registers (Sm75+) + num_queries = warp_uniform(num_queries); + num_keys = warp_uniform(num_keys); + custom_mask_type = warp_uniform(custom_mask_type); + + query_ptr = warp_uniform(query_ptr); + key_ptr = warp_uniform(key_ptr); + value_ptr = warp_uniform(value_ptr); + bias_ptr = warp_uniform(bias_ptr); + logsumexp_ptr = warp_uniform(logsumexp_ptr); + output_ptr = warp_uniform(output_ptr); + grad_output_ptr = warp_uniform(grad_output_ptr); + delta_ptr = warp_uniform(delta_ptr); + + grad_query_ptr = warp_uniform(grad_query_ptr); + grad_key_ptr = warp_uniform(grad_key_ptr); + grad_value_ptr = 
warp_uniform(grad_value_ptr); + grad_bias_ptr = warp_uniform(grad_bias_ptr); + +#if 0 + PRINT_T0("[b:%d h:%d] dp[0]:%f Q:%f K:%f V:%f LSE:%f", + int(blockIdx.z), int(blockIdx.y), + float(delta_ptr[0]), + float(query_ptr[0]), float(key_ptr[0]), float(value_ptr[0]), + float(logsumexp_ptr[0]) + ) +#endif + return true; + } + + __host__ dim3 getBlocksGrid() const { + return dim3(num_splits_key, num_heads, num_batches); + } + __host__ dim3 getThreadsGrid() const { + return dim3(kWarpSize * kNumWarpsPerBlock, 1, 1); + } + CUTLASS_HOST_DEVICE int64_t workspace_elements_gk() const { + if (!kNeedsAccumGradK) { + return 0; + } + return num_splits_key * align_up(num_keys, (int32_t)kBlockSizeJ) * + align_up(head_dim, (int32_t)kBlockSizeI); + } + CUTLASS_HOST_DEVICE int64_t workspace_elements_gv() const { + if (!kNeedsAccumGradV) { + return 0; + } + return num_splits_key * align_up(num_keys, (int32_t)kBlockSizeJ) * + align_up(head_dim_value, (int32_t)kBlockSizeI); + } + CUTLASS_HOST_DEVICE int64_t workspace_elements_gq() const { + if (!kNeedsAccumGradQ) { + return 0; + } + int num_blocks = ceil_div(num_queries, kBlockSizeI); + int num_cols = ceil_div(head_dim, MatmulGradQ::ThreadblockShape::kN); + return num_blocks * num_cols * sizeof(GradQTempStorage) / + sizeof(output_accum_t); + } + CUTLASS_HOST_DEVICE int64_t workspace_strideBH() const { + // Aligned on 128bits + return align_up( + workspace_elements_gk() + workspace_elements_gv() + + workspace_elements_gq(), + int64_t(4)); + } + CUTLASS_HOST_DEVICE int64_t workspace_size() const { + // Returns size of buffer we need to run this kernel + return num_batches * num_heads * workspace_strideBH() * sizeof(float); + } + CUTLASS_HOST_DEVICE bool should_zero_workspace() const { + return num_splits_key > 1; + } + }; + + // shared storage for keeping Zij matrix. not needed if we aren't using + // dropout, in which case we use an empty array to save shared memory + using ZijSharedStorage = typename cutlass::platform::conditional< + kApplyDropout, + typename MatmulQK::AccumulatorSharedStorage, + // dummy shared storage object that takes up no space. + typename cutlass::gemm::threadblock::AccumulatorSharedStorage< +#ifdef _WIN32 + // windows builds throw the error: + // "type containing an unknown-size array is not allowed" + // if we try to make Zij shared storage zero-sized. + // To get around this just make it sized 1 on windows. + typename cutlass::gemm::GemmShape<1, 1, 0>, +#else + typename cutlass::gemm::GemmShape<0, 0, 0>, +#endif + typename MatmulQK::AccumulatorSharedStorage::Element, + typename MatmulQK::AccumulatorSharedStorage::Layout, + typename cutlass::MatrixShape<0, 0>>>::type; + + struct SharedStoragePrologue { + struct { + cutlass::Array di; // (do_i * o_i).sum(-1) + typename MatmulQK::Mma::SharedStorageA mm_qk_k; + } persistent; + union { + struct { + // part1 - after Q.K / dV / dO.V + union { + // 1. efficient load of bias tile Bij, which is then applied to Pij + typename MatmulQK::BiasLoader::SmemTile bias; + // 4. store Pij. it is needed: + // - in dVj += (Pij.T * Zij) @ dOi + // - in dSij = Pij * (dPij - Di) + // 6. dVj += (Pij.T * Zij) @ dOi + // 10. write to fragment + typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; + }; + // 5. store Zij. it is needed in dVj += (Pij.T * Zij) @ dOi + ZijSharedStorage zij; + + union { + // 2. prologue for dVj + // 6. workspace for dVj += (Pij.T * Zij) @ dOi + typename MatmulGradV::Mma::SharedStorage mm_gradV; + // 7. 
dVj epilogue + typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue; + }; + + // 3. prologue for dPij_dropped + // 8. used in dPij_dropped = dOi @ Vj.T + typename MatmulDOIVJ::Mma::SharedStorage mm_doivj; + } part1; + + struct { + // part2 - dQ + union { + typename MatmulQK::AccumulatorSharedStorage + tmpT_shared_storage; // (from part1) + typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; + }; + typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload) + typename MatmulGradQ::Mma::SharedStorage mm_gradQ; // (preload) + union { + // store dB = dSij to global memory + typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue; + typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue; + }; + + } part2; + + struct { + // part3 - after last iteration on dQ's epilogue / dK + union { + typename MatmulQK::AccumulatorSharedStorage + tmpT_shared_storage; // (from part1) + typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; + }; + typename MatmulGradK::Mma::SharedStorage mm_gradK; // (preload) + typename MatmulGradQ::DefaultEpilogue::SharedStorage + gradQ_epilogue_lastIter; + + typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue; + } part3; + + struct { + // part4 - after last iteration on dK's epilogue / preload next K.Q_t + typename MatmulQK::Mma::SharedStorageB mm_qk_q; + + // If we reach end of current key, dump RF->gmem with "final" epilogues + typename MatmulGradK::DefaultEpilogue::SharedStorage + gradK_epilogue_final; + typename MatmulGradV::DefaultEpilogue::SharedStorage + gradV_epilogue_final; + } part4; + }; + static void print_size() { + // Field size +#define FSZ(f) int((sizeof(((SharedStoragePrologue*)0)->f))) + + printf("Total smem: %d bytes\n", int(sizeof(SharedStoragePrologue))); + printf(" persistent: %db\n", FSZ(persistent)); + printf(" mm_qk_k: %db\n", FSZ(persistent.mm_qk_k)); + printf(" part1: %db\n", FSZ(part1)); + printf(" bias: %db\n", FSZ(part1.bias)); + printf(" attn_shared_storage: %db\n", FSZ(part1.attn_shared_storage)); + printf(" zij: %db\n", FSZ(part1.zij)); + printf(" mm_gradV: %db\n", FSZ(part1.mm_gradV)); + printf(" gradV_epilogue: %db\n", FSZ(part1.gradV_epilogue)); + printf(" mm_doivj: %db\n", FSZ(part1.mm_doivj)); + printf(" part2: %db\n", FSZ(part2)); + printf(" tmpT_shared_storage: %db\n", FSZ(part2.tmpT_shared_storage)); + printf(" tmp_shared_storage: %db\n", FSZ(part2.tmp_shared_storage)); + printf(" mm_gradK: %db\n", FSZ(part2.mm_gradK)); + printf(" mm_gradQ: %db\n", FSZ(part2.mm_gradQ)); + printf(" gradB_epilogue: %db\n", FSZ(part2.gradB_epilogue)); + printf(" gradQ_epilogue: %db\n", FSZ(part2.gradQ_epilogue)); + printf(" part3: %db\n", FSZ(part3)); + printf(" tmpT_shared_storage: %db\n", FSZ(part3.tmpT_shared_storage)); + printf(" part4: %db\n", FSZ(part4)); + printf(" mm_qk_q: %db\n", FSZ(part4.mm_qk_q)); + printf( + " gradK_epilogue_final: %db\n", FSZ(part4.gradK_epilogue_final)); + printf( + " gradV_epilogue_final: %db\n", FSZ(part4.gradV_epilogue_final)); + } +// =========================================== +#define FIELD(INSIDE_STRUCT, FIELDNAME) \ + CUTLASS_DEVICE auto& FIELDNAME() { \ + return INSIDE_STRUCT.FIELDNAME; \ + } + + FIELD(persistent, di) + FIELD(persistent, mm_qk_k) + FIELD(part1, bias) + FIELD(part1, attn_shared_storage) + FIELD(part1, zij) + FIELD(part1, mm_gradV) + FIELD(part1, gradV_epilogue) + FIELD(part1, mm_doivj) + FIELD(part2, mm_gradK) + FIELD(part2, mm_gradQ) + FIELD(part2, gradB_epilogue) + FIELD(part2, gradQ_epilogue) + FIELD(part2, 
tmp_shared_storage) + FIELD(part3, tmpT_shared_storage) + FIELD(part3, gradQ_epilogue_lastIter) + FIELD(part3, gradK_epilogue) + FIELD(part4, mm_qk_q) + FIELD(part4, gradK_epilogue_final) + FIELD(part4, gradV_epilogue_final) + }; + + struct SharedStorageNoPrologue { + struct { + cutlass::Array di; // (do_i * o_i).sum(-1) + } persistent; + union { + struct { + // part1 - Q.K matmul + typename MatmulQK::Mma::SharedStorageA mm_qk_k; + typename MatmulQK::Mma::SharedStorageB mm_qk_q; + } part1; + + struct { + // part2 - compute gradV + union { + // 1. efficient load of bias tile Bij, which is then applied to Pij + typename MatmulQK::BiasLoader::SmemTile bias; + // 2. store Pij to shared memory. it is needed: + // - in this step, where it is used in dVj += (Pij.T * Zij) @ dOi + // - in next step where it is used in dSij = Pij * (dPij - Di) + typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; + }; + // 3. store Zij. it is needed in this step, where it is used + // to compute Pij_dropped = Pij * Zij on the fly as fragments of Pij are + // loaded for the computation of dVj. + ZijSharedStorage zij; + + union { + typename MatmulGradV::Mma::SharedStorage mm_gradV; + typename MatmulGradV::DefaultEpilogue::SharedStorage gradV_epilogue; + }; + } part2; + + struct { + // part3 - DO.V matmul + union { + // first compute dPij = (dOi @ Vj.T) * Zij + // and dSij = Pij * (dPij - Di) + struct { + // (from part2) - Pij for computing dSij = Pij * (dPij - Di) + typename MatmulQK::AccumulatorSharedStorage attn_shared_storage; + // matmul to compute dOiVj + typename MatmulDOIVJ::Mma::SharedStorage mm_doivj; + }; + // then store dB = dSij to global memory + typename MatmulDOIVJ::BiasGradEpilogue::SharedStorage gradB_epilogue; + }; + } part3; + + struct { + // part4 - compute gradQ + typename MatmulQK::AccumulatorSharedStorage + tmpT_shared_storage; // (from part2) + typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; + union { + typename MatmulGradQ::Mma::SharedStorage mm_gradQ; + typename MatmulGradQ::DefaultEpilogue::SharedStorage gradQ_epilogue; + typename MatmulGradQ::DefaultEpilogue::SharedStorage + gradQ_epilogue_lastIter; + }; + } part4; + + struct { + // part5 - compute gradK + typename MatmulQK::AccumulatorSharedStorage + tmpT_shared_storage; // (from part2) + typename MatmulDOIVJ::AccumulatorSharedStorage tmp_shared_storage; + union { + typename MatmulGradK::Mma::SharedStorage mm_gradK; + typename MatmulGradK::DefaultEpilogue::SharedStorage gradK_epilogue; + }; + } part5; + + struct { + // part6 - store RF accumulated into gmem + typename MatmulGradK::DefaultEpilogue::SharedStorage + gradK_epilogue_final; + typename MatmulGradV::DefaultEpilogue::SharedStorage + gradV_epilogue_final; + } part6; + }; + static void print_size() { +#define FIELD_SIZEOF(f) int((sizeof(((SharedStorageNoPrologue*)0)->f))) + printf("Total smem: %d bytes\n", int(sizeof(SharedStorageNoPrologue))); + printf(" persistent: %db\n", FIELD_SIZEOF(persistent)); + printf(" part1: %db\n", FIELD_SIZEOF(part1)); + printf(" part2: %db\n", FIELD_SIZEOF(part2)); + printf(" part3: %db\n", FIELD_SIZEOF(part3)); + printf(" part4: %db\n", FIELD_SIZEOF(part4)); + printf(" part5: %db\n", FIELD_SIZEOF(part5)); + printf(" part6: %db\n", FIELD_SIZEOF(part6)); + } +// =========================================== +#define FIELD(INSIDE_STRUCT, FIELDNAME) \ + CUTLASS_DEVICE auto& FIELDNAME() { \ + return INSIDE_STRUCT.FIELDNAME; \ + } + + FIELD(persistent, di) + FIELD(part1, mm_qk_k) + FIELD(part1, mm_qk_q) + FIELD(part2, bias) + 
FIELD(part2, attn_shared_storage) + FIELD(part2, zij) + FIELD(part2, mm_gradV) + FIELD(part2, gradV_epilogue) + FIELD(part3, mm_doivj) + FIELD(part3, gradB_epilogue) + FIELD(part4, tmpT_shared_storage) + FIELD(part4, tmp_shared_storage) + FIELD(part4, mm_gradQ) + FIELD(part4, gradQ_epilogue) + FIELD(part4, gradQ_epilogue_lastIter) + FIELD(part5, mm_gradK) + FIELD(part5, gradK_epilogue) + FIELD(part6, gradK_epilogue_final) + FIELD(part6, gradV_epilogue_final) + }; + + using SharedStorage = typename cutlass::platform::conditional< + kPreload, + SharedStoragePrologue, + SharedStorageNoPrologue>::type; + + struct OutputFragments { + typename MatmulGradV::Mma::FragmentC gradV; + typename MatmulGradK::Mma::FragmentC gradK; + + CUTLASS_DEVICE void clear() { + gradV.clear(); + gradK.clear(); + } + }; + + static bool __host__ check_supported(Params const& p) { + CHECK_ALIGNED_PTR(p.query_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.key_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.value_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.output_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.grad_output_ptr, kMinimumAlignment); + CHECK_ALIGNED_PTR(p.bias_ptr, kMinimumAlignment); + XFORMERS_CHECK(p.lse_strideH % 8 == 0, "LSE is not correctly aligned"); + XFORMERS_CHECK(p.lse_strideB % 8 == 0, "LSE is not correctly aligned"); + XFORMERS_CHECK( + p.num_heads <= 1 || p.q_strideH % kMinimumAlignment == 0, + "query is not correctly aligned (strideH)"); + XFORMERS_CHECK( + p.num_heads <= 1 || p.k_strideH % kMinimumAlignment == 0, + "key is not correctly aligned (strideH)"); + XFORMERS_CHECK( + p.num_heads <= 1 || p.v_strideH % kMinimumAlignment == 0, + "value is not correctly aligned (strideH)"); + XFORMERS_CHECK( + p.num_batches <= 1 || p.q_strideB % kMinimumAlignment == 0, + "query is not correctly aligned (strideB)"); + XFORMERS_CHECK( + p.num_batches <= 1 || p.k_strideB % kMinimumAlignment == 0, + "key is not correctly aligned (strideB)"); + XFORMERS_CHECK( + p.num_batches <= 1 || p.v_strideB % kMinimumAlignment == 0, + "value is not correctly aligned (strideB)"); + XFORMERS_CHECK( + p.q_strideM % kMinimumAlignment == 0, + "query is not correctly aligned (strideM)"); + XFORMERS_CHECK( + p.k_strideM % kMinimumAlignment == 0, + "key is not correctly aligned (strideM)"); + XFORMERS_CHECK( + p.v_strideM % kMinimumAlignment == 0, + "value is not correctly aligned (strideM)"); + if (p.bias_ptr) { + XFORMERS_CHECK( + p.num_batches <= 1 || p.bias_strideB % kMinimumAlignment == 0, + "attn_bias is not correctly aligned (strideB)"); + XFORMERS_CHECK( + p.num_heads <= 1 || p.bias_strideH % kMinimumAlignment == 0, + "attn_bias is not correctly aligned (strideH)"); + XFORMERS_CHECK( + p.bias_strideM % kMinimumAlignment == 0, + "attn_bias is not correctly aligned (strideM)"); + } + if (p.grad_bias_ptr) { + XFORMERS_CHECK( + p.num_batches <= 1 || p.gB_strideB % kMinimumAlignment == 0, + "attn_bias.grad is not correctly aligned (strideB)"); + XFORMERS_CHECK( + p.num_heads <= 1 || p.gB_strideH % kMinimumAlignment == 0, + "attn_bias.grad is not correctly aligned (strideH)"); + XFORMERS_CHECK( + p.gB_strideM % kMinimumAlignment == 0, + "attn_bias.grad is not correctly aligned (strideM)"); + } + XFORMERS_CHECK( + !(p.cu_seqlens_q_ptr && p.bias_ptr), + "CuSeqlen + bias not implemented yet"); + XFORMERS_CHECK( + p.custom_mask_type < NumCustomMaskTypes, + "Invalid value for `custom_mask_type`"); + XFORMERS_CHECK( + p.dropout_prob <= 1.0f && p.dropout_prob >= 0.0f, + "Invalid value for `dropout_prob`"); + XFORMERS_CHECK( + 
kApplyDropout || p.dropout_prob == 0.0f, + "Set `kApplyDropout`=True to support `dropout_prob > 0`"); + XFORMERS_CHECK(p.head_dim > 0, "Invalid value for `head_dim`"); + XFORMERS_CHECK(p.head_dim_value > 0, "Invalid value for `head_dim_value`"); + XFORMERS_CHECK(p.num_queries > 0, "Invalid value for `num_queries`"); + XFORMERS_CHECK(p.num_keys > 0, "Invalid value for `num_keys`"); + XFORMERS_CHECK(p.num_heads > 0, "Invalid value for `num_heads`"); + XFORMERS_CHECK(p.num_batches > 0, "Invalid value for `num_batches`"); + XFORMERS_CHECK(p.head_dim <= kMaxK, "kMaxK: Expected `head_dim < kMaxK`"); + XFORMERS_CHECK( + p.head_dim_value <= kMaxK, "kMaxK: Expected `head_dim_value < kMaxK`"); + if (kKeysQueriesAlignedToBlockSize) { + XFORMERS_CHECK( + p.cu_seqlens_k_ptr == nullptr, + "This kernel does not support cu_seqlen"); + XFORMERS_CHECK( + p.cu_seqlens_q_ptr == nullptr, + "This kernel does not support cu_seqlen"); + XFORMERS_CHECK( + p.num_queries % kBlockSizeI == 0, + "kKeysQueriesAlignedToBlockSize condition not respected"); + XFORMERS_CHECK( + p.num_keys % kBlockSizeJ == 0, + "kKeysQueriesAlignedToBlockSize condition not respected"); + } + XFORMERS_CHECK( + kEnableSplitKeys || p.num_splits_key == 1, "SplitKeys is disabled"); + XFORMERS_CHECK( + p.num_splits_key > 0, "Invalid `num_splits_key` (expected >0)"); + XFORMERS_CHECK( + p.num_splits_key <= cutlass::ceil_div(p.num_keys, kBlockSizeJ), + "Invalid `num_splits_key` (too large)"); + return true; + } + + static CUTLASS_DEVICE void attention_kernel(Params p) { + extern __shared__ char smem_buffer[]; + SharedStorage& shared_storage = *((SharedStorage*)smem_buffer); + + uint16_t thread_id = threadIdx.x; + uint8_t warp_id = warp_uniform(thread_id / 32); + uint8_t lane_id = thread_id % 32; + + int32_t key_start = p.split_key_device() * kBlockSizeJ; + if (key_start >= p.num_keys) { + return; + } + if (kPrologueQK) { + int32_t query_start = getQueryStart(p, key_start); + prologueQkNextIteration( + shared_storage, p, query_start, key_start, warp_id, lane_id); + } + + // Computes (dO*out).sum(-1) and writes it to `p.delta_ptr` + if (kKernelComputesDelta) { + constexpr int kOptimalElements = + 128 / cutlass::sizeof_bits::value; + if (p.head_dim_value % kOptimalElements == 0) { + for (int query_start = 0; query_start < p.num_queries; + query_start += kBlockSizeI) { + computeDelta(p, query_start, warp_id, lane_id); + } + } else { + for (int query_start = 0; query_start < p.num_queries; + query_start += kBlockSizeI) { + computeDelta<1>(p, query_start, warp_id, lane_id); + } + } + __syncthreads(); + } + + OutputFragments output_frags; + + curandStatePhilox4_32_10_t rng_state_init; +#ifdef HAS_PYTORCH + if (kApplyDropout) { + auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs); + // each element of the attention matrix P with shape + // (batch_sz, n_heads, n_queries, n_keys) is associated with a single + // offset in RNG sequence. we initialize the RNG state with offset that + // starts at the beginning of a (n_queries, n_keys) matrix for this + // block's batch_id and head_id + // initializing rng state is very expensive, so we run once per kernel, + // rather than once per iteration. each iteration takes a copy of the + // initialized RNG state and offsets it as needed. 
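+      // For concreteness: dropout_batch_head_rng_offset (set in
+      // advance_to_block above) equals
+      //   batch_id * (num_heads * num_queries * num_keys)
+      //     + head_id * (num_queries * num_keys),
+      // and the per-tile skipahead() further down adds
+      //   (query_start + thread_i) * num_keys + (key_start + thread_start_j),
+      // so an element (b, h, q, k) of the attention matrix always maps to the
+      // same position in the Philox sequence, however the keys are split
+      // across blocks.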
+ curand_init( + std::get<0>(seeds), + 0, + std::get<1>(seeds) + p.dropout_batch_head_rng_offset, + &rng_state_init); + } +#endif + CUTLASS_PRAGMA_UNROLL + for (; key_start < p.num_keys; + key_start += p.num_splits_key_device() * kBlockSizeJ) { + output_frags.clear(); + + CUTLASS_PRAGMA_UNROLL + for (int32_t query_start_shifted = getQueryStart(p, key_start); + query_start_shifted < getQueryStartShift(p) + getQueryEnd(p); + query_start_shifted += kBlockSizeI) { + // This line here + // vvvvvvvvvvvvvv + warp_id = warp_uniform(warp_id); + // ^^^^^^^^^^^^^^ + // ... makes everything use less RF and be 10% faster. Why? + // I don't know. My theory is that it forces `nvcc` to + // re-compute indices, offsets etc... and not keep them + // from the previous iteration, which prevents MASSIVE + // register spilling. + + int32_t query_start = query_start_shifted; + if (query_start >= p.num_queries) { + query_start = query_start % getQueryEnd(p); + } + + processBlockIJ( + shared_storage, + output_frags, + p, + query_start, + key_start, + rng_state_init, + warp_id, + lane_id); + } + if (kOutputInRF) { + writeFragsToGmem( + shared_storage, output_frags, p, key_start, warp_id, lane_id); + } else if (getQueryStart(p, key_start) >= p.num_queries) { + zfillGradKV( + p, key_start, warp_id, lane_id); + } + __syncthreads(); + } + } + + template + static CUTLASS_DEVICE void zfillGradKV( + Params const& p, + int32_t key_start, + uint8_t warp_id, + uint8_t lane_id) { + constexpr int kThreadsPerKey = 8; + constexpr int kParallelKeys = kNumThreads / kThreadsPerKey; + static_assert(kBlockSizeJ % kParallelKeys == 0, ""); + // This function is not really optimized, but should rarely be used + // It's only used when some keys are "useless" and don't attend to + // any query, due to causal masking + + int thread_id = 32 * warp_id + lane_id; + int k_shift = lane_id % kThreadsPerKey; + + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < kBlockSizeJ; j += kParallelKeys) { + int key = key_start + j + (thread_id / kThreadsPerKey); + if (!skipBoundsChecks && key >= p.num_keys) { + continue; + } + auto gv_ptr = p.grad_value_ptr + key * p.gV_strideM(); + auto gk_ptr = p.grad_key_ptr + key * p.gK_strideM(); + + for (int k = k_shift; k < p.head_dim_value; k += kThreadsPerKey) { + gv_ptr[k] = scalar_t(0); + } + for (int k = k_shift; k < p.head_dim; k += kThreadsPerKey) { + gk_ptr[k] = scalar_t(0); + } + } + } + + template + static CUTLASS_DEVICE void processBlockIJ( + SharedStorage& shared_storage, + OutputFragments& output_frags, + Params& p, + int32_t query_start, + int32_t key_start, + const curandStatePhilox4_32_10_t& curand_state_init, + uint8_t warp_id, + uint8_t lane_id) { + cutlass::Array + dropout_keep_mask_doivj; + dropout_keep_mask_doivj.fill(cutlass::uint1b_t{1}); + const float dropout_scale = + kApplyDropout ? 1.0 / (1.0 - p.dropout_prob) : 1.0f; + + cutlass::MatrixCoord no_offset{0, 0}; + accum_t scale = p.scale; + int16_t thread_id = 32 * warp_id + lane_id; + + auto rematerializeThreadIds = [&]() { + // Prevents `nvcc` from keeping values deduced from + // `thread_id`, `warp_id`, ... 
in RF - to reduce register pressure + warp_id = warp_uniform(thread_id / 32); + lane_id = thread_id % 32; + thread_id = 32 * warp_id + lane_id; + }; + + bool isFirstQuery = (query_start == getQueryStart(p, key_start)); + int32_t next_query, next_key; + incrIteration(p, query_start, key_start, next_query, next_key); + bool isLastQuery = next_key != key_start; + + accum_t di_rf = accum_t(0); + if (thread_id < kBlockSizeI) { + if (query_start + thread_id < p.num_queries) { + di_rf = p.delta_ptr[query_start + thread_id]; + } + shared_storage.di()[thread_id] = di_rf; + } + + int32_t num_queries_in_block = skipBoundsChecks + ? MatmulQK::Mma::Shape::kN + : warp_uniform(cutlass::fast_min( + (int32_t)MatmulQK::Mma::Shape::kN, p.num_queries - query_start)); + int32_t num_keys_in_block = skipBoundsChecks + ? MatmulQK::Mma::Shape::kM + : warp_uniform(cutlass::fast_min( + (int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start)); + + auto prologueGradV = [&](int col) { + typename MatmulGradV::Mma::IteratorB iterator_dO( + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM + col, + {num_queries_in_block, p.head_dim_value - col}, + thread_id, + no_offset); + MatmulGradV::Mma::prologue( + shared_storage.mm_gradV(), + iterator_dO, + thread_id, + num_queries_in_block); + }; + auto prologueGradQ = [&](int col) { + typename MatmulGradQ::Mma::IteratorB iterator_K( + {int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM + col, + {num_keys_in_block, p.head_dim - col}, + thread_id, + no_offset); + MatmulGradQ::Mma::prologue( + shared_storage.mm_gradQ(), iterator_K, thread_id, num_keys_in_block); + }; + auto prologueGradK = [&](int col) { + typename MatmulGradK::Mma::IteratorB iterator_Q( + {int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM + col, + {num_queries_in_block, p.head_dim - col}, + thread_id, + no_offset); + MatmulGradK::Mma::prologue( + shared_storage.mm_gradK(), + iterator_Q, + thread_id, + num_queries_in_block); + }; + auto prologueDOV = [&]() { + typename MatmulDOIVJ::Mma::IteratorA iterator_A( + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM, + {num_queries_in_block, p.head_dim_value}, + thread_id, + no_offset); + typename MatmulDOIVJ::Mma::IteratorB iterator_B( + {int32_t(p.v_strideM)}, + p.value_ptr + key_start * p.v_strideM, + {p.head_dim_value, num_keys_in_block}, + thread_id, + no_offset); + MatmulDOIVJ::Mma::prologue( + shared_storage.mm_doivj(), + iterator_A, + iterator_B, + thread_id, + p.head_dim_value); + }; + + ///////////////////////////////////////////////////////////////////////////////////////////////// + // MatmulQK + ///////////////////////////////////////////////////////////////////////////////////////////////// + { + using Mma = typename MatmulQK::Mma; + + cutlass::gemm::GemmCoord problem_size( + num_keys_in_block, + num_queries_in_block, + p.head_dim // k + ); + + // k_j + typename Mma::IteratorA iterator_A( + {int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM, + {problem_size.m(), problem_size.k()}, + thread_id, + no_offset); + + // q_i.transpose(-2, -1) + typename Mma::IteratorB iterator_B( + {int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM, + {problem_size.k(), problem_size.n()}, + thread_id, + no_offset); + + Mma mma( + shared_storage.mm_qk_k(), + shared_storage.mm_qk_q(), + thread_id, + warp_id, + lane_id); + + typename Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = + (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute 
threadblock-scoped matrix multiply-add + mma.set_prologue_done(kPrologueQK); + mma.set_zero_outside_bounds(!skipBoundsChecks); + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + accum = cutlass::multiplies()(scale, accum); + + // Epilogue: add LSE + exp and store that to our shared memory buffer + // shmem <- (matmul_result - + // logsumexp[i_start:i_end].unsqueeze(1)).exp() + int warp_idx_mn_0 = + warp_id % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN); + auto output_tile_coords = cutlass::MatrixCoord{ + warp_idx_mn_0 % Mma::Base::WarpCount::kM, + warp_idx_mn_0 / Mma::Base::WarpCount::kM}; + + // apply bias if applicable + if (p.bias_ptr != nullptr) { + // load bias tile Bij into shared memory + typename MatmulQK::BiasLoader::GmemTileIterator bias_iter( + {cutlass::layout::RowMajor(p.bias_strideM)}, + p.bias_ptr + query_start * p.bias_strideM + key_start, + {num_queries_in_block, num_keys_in_block}, + thread_id); + cutlass::TensorRef bias_tensor_ref( + shared_storage.bias().data(), + cutlass::layout::RowMajor(MatmulQK::ThreadblockShape::kM)); + typename MatmulQK::BiasLoader::SmemTileIterator smem_tile_iter( + bias_tensor_ref, thread_id); + MatmulQK::BiasLoader::load(bias_iter, smem_tile_iter); + + // Pij += Bij, where Pij is in register fragment and Bij is in shmem + auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset( + lane_id, warp_id, output_tile_coords); + MatmulQK::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_n) {}, + [&](int accum_m, int accum_n, int idx) { + // remember we are transposed + accum[idx] += bias_tensor_ref.at({accum_n, accum_m}); + }, + [&](int accum_n) {}); + } + + // Apply mask + if (p.custom_mask_type == CausalFromTopLeft || + p.custom_mask_type == CausalFromBottomRight) { + auto lane_offset = MatmulQK::AccumLambdaIterator::get_lane_offset( + lane_id, warp_id, output_tile_coords); + int shift = query_start - key_start; + if (p.custom_mask_type == CausalFromBottomRight) { + shift += p.num_keys - p.num_queries; + } + // current_key = key_start + accum_m + // current_query = query_start + accum_n + // mask if: `current_key > current_query` + MatmulQK::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (accum_m > accum_n + shift) { + accum[idx] = + -cutlass::platform::numeric_limits::infinity(); + } + }, + [&](int accum_m) {}); + } + + __syncthreads(); + if (kPrologueGV) { + prologueGradV(0); + } + if (kPrologueDOV) { + prologueDOV(); + } + + MatmulQK::B2bGemm::accumApplyLSEToSmem( + shared_storage.attn_shared_storage(), + accum, + p.logsumexp_ptr + query_start, + problem_size.n(), + thread_id, + warp_id, + lane_id, + output_tile_coords); +#if 0 + auto accum_ref_attnT = shared_storage.attn_shared_storage().accum_ref(); + PRINT_TENSOR4x4_T0_L0("attn_T", accum_ref_attnT); +#endif + + // if we are using dropout, compute Zij, writing it to shared memory. + // each element of Zij is: + // - 0 with probability dropout_p + // - 1 / (1 - dropout_p) with probability 1 - dropout_p + if (kApplyDropout) { + auto zij = shared_storage.zij().accum_ref(); + // each thread generates a contiguous sequence of elements in Zij, all + // in the same row. the reason they have to come from the same row is + // that sampling random numbers from a contiguous random number sequence + // is much more efficient than jumping around, and the linear offset of + // each element of Z (the global matrix) maps to an offset in a random + // number sequence. 
for Z, the end of a row and the beginning of the + // next have adjacent offsets, but for Zij (tile of global matrix), this + // is not necessarily the case. + // We must fill the entire `zij` shmem with values (even out of bounds + // on the K-dimension) otherwise we can get NaNs during the GEMM + const int kQueriesPerBlock = kBlockSizeI; + const int threads_per_row = cutlass::fast_min( + int32_t(kNumThreads / kQueriesPerBlock), num_keys_in_block); + const int elts_per_thread = cutlass::round_nearest( + cutlass::ceil_div(num_keys_in_block, threads_per_row), 4); + + const int thread_i = thread_id / threads_per_row; + const int thread_start_j = + (thread_id % threads_per_row) * elts_per_thread; + + if (thread_i < kQueriesPerBlock && thread_start_j < num_keys_in_block) { + curandStatePhilox4_32_10_t curand_state = curand_state_init; + skipahead( + (query_start + thread_i) * p.num_keys + + (key_start + thread_start_j), + &curand_state); + + // generate elements of Zij, 4 elements at a time + for (int zij_start_col_idx = thread_start_j; zij_start_col_idx < + cutlass::fast_min(thread_start_j + elts_per_thread, + num_keys_in_block); + zij_start_col_idx += 4) { + const float4 rand_uniform_quad = curand_uniform4(&curand_state); + + CUTLASS_PRAGMA_UNROLL + for (int quad_idx = 0; quad_idx < 4; ++quad_idx) { + // we'll write Zij transposed since attention is also transposed + // during the matmul to compute dV. + zij.at({zij_start_col_idx + quad_idx /*k*/, thread_i /*q*/}) = + (&rand_uniform_quad.x)[quad_idx] > p.dropout_prob + ? scalar_t(dropout_scale) + : scalar_t(0); + } + } + } + __syncthreads(); +#if 0 + PRINT_TENSOR4x4_T0_L0("zij", zij); + PRINT_TENSOR4x4_T0_L0_START("zij", zij, kBlockSizeJ - 4, kBlockSizeI - 4); +#endif + + // Save mask for later DOIVJ matmul + + int warp_idx_mn_0 = warp_id % + (MatmulDOIVJ::Mma::Base::WarpCount::kM * + MatmulDOIVJ::Mma::Base::WarpCount::kN); + auto output_tile_coords_doivj = cutlass::MatrixCoord{ + warp_idx_mn_0 % MatmulDOIVJ::Mma::Base::WarpCount::kM, + warp_idx_mn_0 / MatmulDOIVJ::Mma::Base::WarpCount::kM}; + auto lane_offset = MatmulDOIVJ::AccumLambdaIterator::get_lane_offset( + lane_id, warp_id, output_tile_coords_doivj); + MatmulDOIVJ::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m /*q*/, int accum_n /*k*/, int idx) { + if (zij.at({accum_n, accum_m}) == scalar_t(0)) { + dropout_keep_mask_doivj[idx] = cutlass::uint1b_t{0}; + } + }, + [&](int accum_m) {}); + } + __syncthreads(); + } + rematerializeThreadIds(); + + ///////////////////////////////////////////////////////////////////////////////////////////////// + // GradV matmul + // + // grad_v[j_start:j_end] += attn_T @ do_i + ///////////////////////////////////////////////////////////////////////////////////////////////// + constexpr bool kSingleIterationGradV = + kMaxK <= MatmulGradV::ThreadblockShape::kN; + for (int col = 0; col < (kSingleIterationGradV ? 
1 : p.head_dim_value); + col += MatmulGradV::ThreadblockShape::kN) { + using Mma = typename MatmulGradV::Mma; + using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; + + cutlass::gemm::GemmCoord problem_size( + num_keys_in_block, p.head_dim_value - col, num_queries_in_block); + auto createEpilogueIter = [&]() { + return typename MatmulGradV::OutputTileIterator( + typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()}, + p.grad_value_ptr + key_start * p.gV_strideM() + col, + {num_keys_in_block, p.head_dim_value - col}, + thread_id); + }; + typename Mma::IteratorB iterator_B( + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM + col, + {num_queries_in_block, p.head_dim_value - col}, + thread_id, + no_offset); + + // if dropout: dVj += (Pij.T * Zij) @ dOi + // otherwise: dVj += Pij.T @ dOi + Mma mma( + // operand A: Pij.T + shared_storage.attn_shared_storage().accum_ref(), + // operand A_scale Zij.T: + // if we're using dropout, operand A is Pij_dropped.T = Pij.T * Zij.T + // which is computed on the fly as fragments of Pij.T are loaded in + shared_storage.zij().accum_ref(), + // operand B: dOi - which was loaded into shared memory previously + // when we computed dVj + shared_storage.mm_gradV().operand_B_ref(), + thread_id, + warp_id, + lane_id); + + int storage_id = col / MatmulGradV::ThreadblockShape::kN; + AccumTileGmem gmem_tile{ + p.workspace_gv + storage_id * AccumTileGmem::kElementsStored}; + if (!kOutputInRF) { + if (isFirstQuery || !kNeedsAccumGradV) { + output_frags.gradV.clear(); + } else { + gmem_tile.load(output_frags.gradV, thread_id); + } + } + mma.set_prologue_done(kPrologueGV); + + auto gemm_k_iterations = + (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + __syncthreads(); + + mma(gemm_k_iterations, + output_frags.gradV, + iterator_B, + output_frags.gradV); + __syncthreads(); + if (kPrologueGV && !kSingleIterationGradV && + col + MatmulGradV::ThreadblockShape::kN < p.head_dim_value) { + prologueGradV(col + MatmulGradV::ThreadblockShape::kN); + } + + if (!kOutputInRF) { + if (kNeedsAccumGradV && !isLastQuery) { + gmem_tile.store(output_frags.gradV, thread_id); + } else { + accumulateInGmem( + shared_storage.gradV_epilogue(), + output_frags.gradV, + createEpilogueIter(), + isFirstQuery || kNeedsAccumGradV, + warp_id, + lane_id); + } + } + } + __syncthreads(); + + ///////////////////////////////////////////////////////////////////////////////////////////////// + // MatmulDOIVJ + ///////////////////////////////////////////////////////////////////////////////////////////////// + { + using Mma = typename MatmulDOIVJ::Mma; + // do_i + typename Mma::IteratorA iterator_A( + {int32_t(p.gO_strideM)}, + p.grad_output_ptr + query_start * p.gO_strideM, + {num_queries_in_block, p.head_dim_value}, + thread_id, + no_offset); + + // v_j.transpose(-2, -1) + typename Mma::IteratorB iterator_B( + {int32_t(p.v_strideM)}, + p.value_ptr + key_start * p.v_strideM, + {p.head_dim_value, num_keys_in_block}, + thread_id, + no_offset); + + Mma mma(shared_storage.mm_doivj(), thread_id, warp_id, lane_id); + mma.set_prologue_done(kPrologueDOV); + mma.set_zero_outside_bounds(!skipBoundsChecks); + + typename Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = + (p.head_dim_value + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + __syncthreads(); + if (kPrologueGQ) { + 
prologueGradQ(0); + } + if (kPrologueGK) { + prologueGradK(0); + } + + int warp_idx_mn_0 = + warp_id % (Mma::Base::WarpCount::kM * Mma::Base::WarpCount::kN); + auto output_tile_coords = cutlass::MatrixCoord{ + warp_idx_mn_0 % Mma::Base::WarpCount::kM, + warp_idx_mn_0 / Mma::Base::WarpCount::kM}; + // TODO: This must be terribly inefficient. There must be a better way + // tmp [RF] <- (accum [RF] - Di [smem] ) * attn_T.T [smem] + // attn_shared_storage [smem] <- tmp.T + // tmp_shared_storage [smem] <- tmp + { + using LambdaIterator = typename MatmulDOIVJ::AccumLambdaIterator; + auto lane_offset = LambdaIterator::get_lane_offset( + lane_id, warp_id, output_tile_coords); + // if dropout was used, compute dPij = dPij_dropped * Zij + if (kApplyDropout) { + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (dropout_keep_mask_doivj[idx].get()) { + accum[idx] *= dropout_scale; + } else { + accum[idx] = 0; + } + }, + [&](int accum_m) {}); + } + + auto attn_T = shared_storage.attn_shared_storage().accum_ref(); +#if 0 + PRINT_B0_T0("doivj_dropped"); + print_warp_accum(accum, lane_offset, 4, 4); + PRINT_TENSOR4x4_T0_L0("attn_T", attn_T) +#endif + accum_t current_di; + // dSij = (dPij - Di) * Pij + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { current_di = shared_storage.di()[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + // TODO: Otherwise we can get nans as we + // might have infs here (only seen on f16 tho) + if (skipBoundsChecks || + (accum_m < num_queries_in_block && + accum_n < num_keys_in_block)) { + accum_t attn = attn_T.at({accum_n, accum_m}); + accum[idx] = (accum[idx] - current_di) * attn; + } else { + accum[idx] = 0; + } + }, + [&](int accum_m) { + + }); + + // store bias gradient tile dBij to global memory, + // where dBij = dSij = Pij * (dPij - Di) + if (p.grad_bias_ptr != nullptr) { + typename MatmulDOIVJ::BiasGradEpilogue::OutputTileIterator + output_iter( + typename MatmulDOIVJ::BiasGradEpilogue::OutputTileIterator:: + Params{p.gB_strideM}, + // grad_bias_ptr is offset to point at beginning of + // matrix of shape (queries, keys) for a given + // (batch_id, head_id) the pointer arithmetic here produces + // a pointer to the start of the current tile within that + // matrix + p.grad_bias_ptr + query_start * p.gB_strideM + key_start, + {num_queries_in_block, num_keys_in_block}, + thread_id); + + // no-op epilogue operator - just casting and storing contents of + // accum to global memory + typename MatmulDOIVJ::BiasGradEpilogue::OutputOp output_op( + typename MatmulDOIVJ::BiasGradEpilogue::OutputOp::Params{1, 1}); + typename MatmulDOIVJ::BiasGradEpilogue epilogue( + shared_storage.gradB_epilogue(), thread_id, warp_id, lane_id); + epilogue(output_op, output_iter, accum, output_iter); + } + + accum = accum * scale; + +#if 0 + PRINT_B0_T0("(doivj - di) * attn * scale"); + print_warp_accum(accum, lane_offset, 4, 4); +#endif + + __syncthreads(); + if (!MatmulGradK::DefaultMmaFromSmem::kIsTransposedA) { + auto tmpT = shared_storage.tmpT_shared_storage().accum_ref(); + // attn <- attn_T.T + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + tmpT.at({accum_n, accum_m}) = scalar_t(accum[idx]); + }, + [&](int accum_m) {}); + } + } + + MatmulDOIVJ::B2bGemm::accumToSmem( + shared_storage.tmp_shared_storage(), + accum, + lane_id, + output_tile_coords); + __syncthreads(); + } + // Force `nvcc` to recompute values that depend on the variables just 
below + // to use less RF and prevent some spilling + p.head_dim = warp_uniform(p.head_dim); + p.k_strideM = warp_uniform(p.k_strideM); + rematerializeThreadIds(); + + ///////////////////////////////////////////////////////////////////////////////////////////////// + // GradQ matmul + // + // grad_q[i_start:i_end] += tmp @ k_j + ///////////////////////////////////////////////////////////////////////////////////////////////// + // Skip the loop & associated branches if we know at compile time the number + // of iterations + constexpr bool kSingleIterationGradQ = + kMaxK <= MatmulGradQ::ThreadblockShape::kN; + for (int col = 0; col < (kSingleIterationGradQ ? 1 : p.head_dim); + col += MatmulGradQ::ThreadblockShape::kN) { + using Mma = typename MatmulGradQ::Mma; + using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; + + cutlass::gemm::GemmCoord problem_size( + num_queries_in_block, + false ? MatmulGradQ::ThreadblockShape::kN : p.head_dim - col, + num_keys_in_block); + + // k_j + typename Mma::IteratorB iterator_B( + {int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM + col, + {problem_size.k(), problem_size.n()}, + thread_id, + no_offset); + + auto a = shared_storage.tmp_shared_storage().accum_ref(); + Mma mma( + // operand A: dSij + shared_storage.tmp_shared_storage().accum_ref(), + // operand B: Kj + shared_storage.mm_gradQ().operand_B_ref(), + thread_id, + warp_id, + lane_id); + + typename Mma::FragmentC accum; + + int col_id = col / MatmulGradQ::ThreadblockShape::kN; + int num_cols = kSingleIterationGradQ + ? 1 + : ceil_div(p.head_dim, MatmulGradQ::ThreadblockShape::kN); + int storage_id = (col_id + query_start / kBlockSizeI * num_cols); + + if (p.num_splits_key_device() > 1) { + AtomicLock::acquire( + &p.workspace_gq[storage_id].lock, + p.split_key_device() + 1, + thread_id); + // Make sure we can see other block's output + __threadfence(); + } + + AccumTileGmem gmem_tile{&p.workspace_gq[storage_id].buffer[0]}; + if (!kNeedsAccumGradQ || + (p.num_splits_key_device() == 1 && key_start == 0)) { + // if we know we are the first to access it, we know it's only zeros. 
+ // Avoids a load from gmem (and gmem init as well) + accum.clear(); + } else { + gmem_tile.load(accum, thread_id); + } + + auto gemm_k_iterations = + (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + __syncthreads(); + mma.set_prologue_done(kPrologueGQ); + mma(gemm_k_iterations, accum, iterator_B, accum); + __syncthreads(); + bool isLastColumn = kSingleIterationGradQ || + (col + MatmulGradQ::ThreadblockShape::kN >= p.head_dim); + if (kPrologueGQ && !isLastColumn) { + prologueGradQ(col + MatmulGradQ::ThreadblockShape::kN); + } + + bool isLast = [&]() { + int32_t next_key = key_start + p.num_splits_key_device() * kBlockSizeJ; + if (p.num_keys <= next_key) { + return true; + } + if (query_start < getSmallestQueryForKey(p, next_key)) { + return true; + } + return false; + }(); + // Output results + if (p.num_splits_key_device() > 1) { + int32_t numAddsSoFar = -1; + if (isLast && thread_id == 0) { + numAddsSoFar = atomicAdd(&p.workspace_gq[storage_id].counter, 1) + + 1; // `atomicAdd` returns the old value + } + isLast = __syncthreads_or( + numAddsSoFar == getNumParallelBlocksForQuery(p, query_start)); + assert(numAddsSoFar <= getNumParallelBlocksForQuery(p, query_start)); + } + if (kNeedsAccumGradQ && !isLast) { + gmem_tile.store(accum, thread_id); + if (p.num_splits_key_device() > 1) { + // Make sure everyone wrote before we release the lock + __threadfence(); + __syncthreads(); + AtomicLock::release(&p.workspace_gq[storage_id].lock, thread_id); + } + } else { + // NOTE: We're not releasing the lock because no one is expected + // to come after us (we're the last one to write) + typename MatmulGradQ::OutputTileIterator output_it( + typename MatmulGradQ::OutputTileIterator::Params{p.gQ_strideM()}, + p.grad_query_ptr + query_start * p.gQ_strideM() + col, + {problem_size.m(), problem_size.n()}, + thread_id); + bool storage_contains_zeros = kNeedsAccumGradQ || key_start == 0 || + (p.num_splits_key_device() > 1); + accumulateInGmem( + isLastColumn ? shared_storage.gradQ_epilogue_lastIter() + : shared_storage.gradQ_epilogue(), + accum, + output_it, + storage_contains_zeros, + warp_id, + lane_id); + } + } + ///////////////////////////////////////////////////////////////////////////////////////////////// + // GradK matmul + // + // grad_k[i_start:i_end] += tmp.transpose(-2, -1) @ q_i + ///////////////////////////////////////////////////////////////////////////////////////////////// + rematerializeThreadIds(); + + constexpr bool kSingleIterationGradK = + kMaxK <= MatmulGradK::ThreadblockShape::kN; + for (int col = 0; col < (kSingleIterationGradK ? 1 : p.head_dim); + col += MatmulGradK::ThreadblockShape::kN) { + using Mma = typename MatmulGradK::Mma; + using AccumTileGmem = typename MatmulGradQ::AccumTileGmem; + + cutlass::gemm::GemmCoord problem_size( + num_keys_in_block, + false ? MatmulGradK::ThreadblockShape::kN : p.head_dim - col, + num_queries_in_block); + auto createEpilogueIter = [&]() { + return typename MatmulGradK::OutputTileIterator( + typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()}, + p.grad_key_ptr + key_start * p.gK_strideM() + col, + {num_keys_in_block, + false ? 
MatmulGradK::ThreadblockShape::kN : p.head_dim - col}, + thread_id); + }; + + // q_i + typename Mma::IteratorB iterator_B( + {int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM + col, + {problem_size.k(), problem_size.n()}, + thread_id, + no_offset); + + auto getTmp = [&](int) { return &shared_storage.tmp_shared_storage(); }; + auto getTmpT = [&](int) { return &shared_storage.tmpT_shared_storage(); }; + // this is basically: + // opA = kIsTransposedA ? getTmp() : getTmpT(); + bool constexpr kIsTransposedA = + MatmulGradK::DefaultMmaFromSmem::kIsTransposedA; + auto& opA = *call_conditional< + kIsTransposedA, + decltype(getTmp), + decltype(getTmpT)>::apply(getTmp, getTmpT, 0); + Mma mma( + // operand A: dSij.T + opA.accum_ref(), + // operand B: Qi + shared_storage.mm_gradK().operand_B_ref(), + thread_id, + warp_id, + lane_id); + + int storage_id = col / MatmulGradK::ThreadblockShape::kN; + AccumTileGmem gmem_tile{ + p.workspace + storage_id * AccumTileGmem::kElementsStored}; + if (!kOutputInRF) { + if (isFirstQuery || !kNeedsAccumGradK) { + output_frags.gradK.clear(); + } else { + gmem_tile.load(output_frags.gradK, thread_id); + } + } + mma.set_prologue_done(kPrologueGK); + + auto gemm_k_iterations = + (problem_size.k() + Mma::Shape::kK - 1) / Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + __syncthreads(); + + mma(gemm_k_iterations, + output_frags.gradK, + iterator_B, + output_frags.gradK); + __syncthreads(); + bool isLastColumn = kSingleIterationGradK || + col + MatmulGradK::ThreadblockShape::kN >= p.head_dim; + if (kPrologueGK && !isLastColumn) { + prologueGradK(col + MatmulGradK::ThreadblockShape::kN); + } + + if (kPrologueQK && isLastColumn) { + int32_t next_query, next_key; + incrIteration(p, query_start, key_start, next_query, next_key); + DISPATCH_BOOL( + next_key != key_start, kForceReloadK, ([&]() { + prologueQkNextIteration( + shared_storage, p, next_query, next_key, warp_id, lane_id); + })); + } + + // Output results + if (!kOutputInRF) { + if (kNeedsAccumGradK && !isLastQuery) { + gmem_tile.store(output_frags.gradK, thread_id); + } else { + accumulateInGmem( + isLastColumn ? 
shared_storage.gradK_epilogue_final() + : shared_storage.gradK_epilogue(), + output_frags.gradK, + createEpilogueIter(), + isFirstQuery || kNeedsAccumGradK, + warp_id, + lane_id); + __syncthreads(); + } + } + } + } + + static CUTLASS_DEVICE int32_t getQueryStartShift(Params const& p) { + if (p.custom_mask_type == NoCustomMask && p.num_splits_key_device() > 1) { + return (p.split_key_device() * kBlockSizeI) % getQueryEnd(p); + } + return 0; + } + + // Iteration order logic + static CUTLASS_DEVICE int32_t + getQueryStart(Params const& p, int32_t key_start) { + return getSmallestQueryForKey(p, key_start) + getQueryStartShift(p); + }; + static CUTLASS_DEVICE int32_t getQueryEnd(Params const& p) { + return align_up(p.num_queries, kBlockSizeI); + }; + + static CUTLASS_DEVICE int32_t + getSmallestQueryForKey(Params const& p, int32_t key_start) { + if (p.custom_mask_type == CausalFromTopLeft) { + return (key_start / kBlockSizeI) * kBlockSizeI; + } else if (p.custom_mask_type == CausalFromBottomRight) { + int first_query = + cutlass::fast_max(0, key_start - p.num_keys + p.num_queries); + return (first_query / kBlockSizeI) * kBlockSizeI; + } + return 0; + }; + + // Returns how many kernel blocks will write to a given block in `grad_query` + // This is usually equal to the number of key splits, but can be different + // for instance in the causal case, or varying seqlen + static CUTLASS_DEVICE int32_t + getNumParallelBlocksForQuery(Params const& p, int32_t query_start) { + int16_t num_key_blocks = ceil_div(p.num_keys, kBlockSizeJ); + if (p.custom_mask_type == CausalFromTopLeft) { + int32_t last_key_for_block = query_start + kBlockSizeI - 1; + last_key_for_block = cutlass::fast_min(last_key_for_block, p.num_keys); + num_key_blocks = ceil_div(last_key_for_block, kBlockSizeJ); + } else if (p.custom_mask_type == CausalFromBottomRight) { + int32_t last_key_for_block = + query_start + (kBlockSizeI - 1) + (1 + p.num_keys - p.num_queries); + last_key_for_block = cutlass::fast_min(last_key_for_block, p.num_keys); + num_key_blocks = ceil_div(last_key_for_block, kBlockSizeJ); + } + return cutlass::fast_min(p.num_splits_key_device(), num_key_blocks); + }; + + // Returns the next block to process + static CUTLASS_DEVICE void incrIteration( + Params const& p, + int32_t query_start, + int32_t key_start, + int32_t& next_query, + int32_t& next_key) { + next_query = query_start + kBlockSizeI; + next_key = key_start; + auto query_shift = getQueryStartShift(p); + // Wrap around + if (query_shift) { + if (next_query >= p.num_queries) { + next_query = getSmallestQueryForKey(p, key_start); + return; + } else if (query_start < query_shift && query_shift <= next_query) { + // jump to next key + } else { + return; + } + } else { + if (next_query < p.num_queries) { + return; + } + // jump to next key + } + // Next key + next_key = key_start + p.num_splits_key_device() * kBlockSizeJ; + next_query = getQueryStart(p, next_key); + } + + template + static CUTLASS_DEVICE void prologueQkNextIteration( + SharedStorage& shared_storage, + Params const& p, + int32_t query_start, + int32_t key_start, + uint8_t warp_id, + uint8_t lane_id) { + if (query_start >= p.num_queries || key_start >= p.num_keys) { + return; + } + + static constexpr bool kReloadK = + kForceReloadK || !MatmulQK::Mma::kSmemContainsEntireMat; + int thread_id = 32 * warp_id + lane_id; + typename MatmulQK::Mma::IteratorA iterator_A( + {int32_t(p.k_strideM)}, + p.key_ptr + key_start * p.k_strideM, + {p.num_keys - key_start, p.head_dim}, + thread_id, + 
cutlass::MatrixCoord{0, 0}); + + typename MatmulQK::Mma::IteratorB iterator_B( + {int32_t(p.q_strideM)}, + p.query_ptr + query_start * p.q_strideM, + {p.head_dim, p.num_queries - query_start}, + thread_id, + cutlass::MatrixCoord{0, 0}); + + MatmulQK::Mma::template prologue( + shared_storage.mm_qk_k(), + shared_storage.mm_qk_q(), + iterator_A, + iterator_B, + thread_id, + p.head_dim); + } + + template + static CUTLASS_DEVICE void writeFragsToGmem( + SharedStorage& shared_storage, + OutputFragments& output_frags, + Params const& p, + int32_t key_start, + uint8_t warp_id, + uint8_t lane_id) { + uint16_t thread_id = 32 * warp_id + lane_id; + int32_t num_keys_in_block = skipBoundsChecks + ? MatmulQK::Mma::Shape::kM + : cutlass::fast_min( + (int32_t)MatmulQK::Mma::Shape::kM, p.num_keys - key_start); + typename MatmulGradV::OutputTileIterator outputV_it( + typename MatmulGradV::OutputTileIterator::Params{p.gV_strideM()}, + p.grad_value_ptr + key_start * p.gV_strideM(), + {num_keys_in_block, p.head_dim_value}, + thread_id); + + accumulateInGmem( + shared_storage.gradV_epilogue_final(), + output_frags.gradV, + outputV_it, + true, + warp_id, + lane_id); + + typename MatmulGradK::OutputTileIterator outputK_it( + typename MatmulGradK::OutputTileIterator::Params{p.gK_strideM()}, + p.grad_key_ptr + key_start * p.gK_strideM(), + {num_keys_in_block, + false ? MatmulGradK::ThreadblockShape::kN : p.head_dim}, + thread_id); + accumulateInGmem( + shared_storage.gradK_epilogue_final(), + output_frags.gradK, + outputK_it, + true, + warp_id, + lane_id); + } + + template + static CUTLASS_DEVICE void accumulateInGmem( + typename MatmulT::DefaultEpilogue::SharedStorage& epilogue_smem, + typename MatmulT::Mma::FragmentC const& accum, + typename MatmulT::OutputTileIterator output_it, + bool first, + uint8_t warp_id, + uint8_t lane_id) { + using DefaultEpilogue = typename MatmulT::DefaultEpilogue; + using DefaultOutputOp = typename MatmulT::DefaultOutputOp; + using Mma = typename MatmulT::Mma; + int thread_id = 32 * warp_id + lane_id; + DISPATCH_BOOL( + first, kIsFirst, ([&]() { + static constexpr auto ScaleType = kIsFirst::value + ? 
cutlass::epilogue::thread::ScaleType::Nothing + : cutlass::epilogue::thread::ScaleType::NoBetaScaling; + using EpilogueOutputOp = + typename cutlass::epilogue::thread::LinearCombination< + typename DefaultOutputOp::ElementOutput, + DefaultOutputOp::kCount, + typename DefaultOutputOp::ElementAccumulator, + typename DefaultOutputOp::ElementCompute, + ScaleType>; + using Epilogue = + typename cutlass::epilogue::threadblock::EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename MatmulT::OutputTileIterator, + typename DefaultEpilogue::AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true // IterationsUnroll + >; + EpilogueOutputOp rescale({1, 1}); + Epilogue epilogue(epilogue_smem, thread_id, warp_id, lane_id); + epilogue(rescale, output_it, accum, output_it); + })); + } + + template + static CUTLASS_DEVICE void computeDelta( + Params const& p, + int32_t query_start, + uint8_t warp_id, + uint8_t lane_id) { + // Each thread computes one value for Delta + // Depending on warp configuration, we might have multiple + // threads of the same warp working on the same row + using AccessType = cutlass::Array; + static_assert(kNumThreads >= kBlockSizeI, ""); + static constexpr int kNumThreadsPerLine = kNumThreads / kBlockSizeI; + int16_t thread_id = 32 * warp_id + lane_id; + + int16_t laneFirstCol = kElementsPerAccess * (lane_id % kNumThreadsPerLine); + int16_t laneRow = thread_id / kNumThreadsPerLine; + bool rowPred = (query_start + laneRow) < p.num_queries; + bool pred = rowPred; + + // on windows, previous syntax __restrict__ AccessType* + // resulted in error: "restrict" is not allowed + const AccessType* __restrict__ grad_output_ptr = + reinterpret_cast( + p.grad_output_ptr + (query_start + laneRow) * p.gO_strideM + + laneFirstCol); + const AccessType* __restrict__ output_ptr = + reinterpret_cast( + p.output_ptr + (query_start + laneRow) * p.o_strideM() + + laneFirstCol); + + static constexpr int64_t kMaxIters = + kMaxK / (kElementsPerAccess * kNumThreadsPerLine); + constexpr int kPipelineStages = 2; + accum_t delta_value = accum_t(0); + using GlobalLoad = + cutlass::arch::global_load; + AccessType frag_grad_output[kPipelineStages]; + AccessType frag_output[kPipelineStages]; + + auto loadAndIncrement = [&](int ld_pos, bool is_valid) { + frag_grad_output[ld_pos].clear(); + frag_output[ld_pos].clear(); + GlobalLoad(frag_grad_output[ld_pos], grad_output_ptr, is_valid); + GlobalLoad(frag_output[ld_pos], output_ptr, is_valid); + grad_output_ptr += kNumThreadsPerLine; + output_ptr += kNumThreadsPerLine; + }; + + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < kPipelineStages - 1; ++iter) { + int ld_pos = iter % kPipelineStages; + pred = pred && + (laneFirstCol + iter * kElementsPerAccess * kNumThreadsPerLine) < + p.head_dim_value; + loadAndIncrement(ld_pos, pred); + } + auto columnIteration = [&](int iter) { + // Load for next iter + int ld_pos = (iter + kPipelineStages - 1) % kPipelineStages; + pred = pred && + (laneFirstCol + + (iter + kPipelineStages - 1) * kElementsPerAccess * + kNumThreadsPerLine) < p.head_dim_value; + loadAndIncrement(ld_pos, pred); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < AccessType::kElements; ++i) { + delta_value += accum_t(frag_output[iter % kPipelineStages][i]) * + accum_t(frag_grad_output[iter % kPipelineStages][i]); + } + }; + 
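+    // In scalar terms, the columnIteration calls below accumulate, for the
+    // row q = query_start + laneRow,
+    //   delta[q] = sum_d grad_output[q, d] * output[q, d]
+    // (completed by the cross-lane shuffle reduction at the end). This is the
+    // (do_i * o_i).sum(-1) value consumed as Di in dSij = Pij * (dPij - Di);
+    // the software pipelining only reorders the loads, not the reduction.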
+ // If we have a small lower-bound for K, we can unroll the loop + if (kMaxK <= 256) { + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < kMaxIters; ++iter) { + columnIteration(iter); + } + } else { + int num_iters = + ceil_div(p.head_dim_value, kElementsPerAccess * kNumThreadsPerLine) * + (kElementsPerAccess * kNumThreadsPerLine); + for (int iter = 0; iter < num_iters; ++iter) { + columnIteration(iter); + } + } + + // Reduce between workers + static_assert( + kNumThreadsPerLine == 1 || kNumThreadsPerLine == 2 || + kNumThreadsPerLine == 4, + ""); + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < kNumThreadsPerLine; i *= 2) { + delta_value = delta_value + __shfl_xor_sync(0xffffffff, delta_value, i); + } + + // Store in gmem + if (rowPred) { + p.delta_ptr[query_start + laneRow] = delta_value; + } + } +}; + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_backward_batched_impl(typename AK::Params p) { + if (!p.advance_to_block()) { + return; + } + AK::attention_kernel(p); +} + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_backward_batched(typename AK::Params params); diff --git a/examples/41_fused_multi_head_attention/kernel_forward.h b/examples/41_fused_multi_head_attention/kernel_forward.h new file mode 100644 index 0000000000..71d79415e9 --- /dev/null +++ b/examples/41_fused_multi_head_attention/kernel_forward.h @@ -0,0 +1,1322 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#ifdef HAS_PYTORCH +#include +#include +#endif + +#include +#include +#include +#include + +#include "cutlass/fast_math.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/vector.h" +#include "cutlass/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/tensor_ref.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_simt.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" +#include "cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h" +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/kernel/default_gemm.h" +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/gemm/threadblock/default_mma_core_simt.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm70.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm75.h" +#include "cutlass/gemm/threadblock/default_mma_core_sm80.h" +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/platform/platform.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "debug_utils.h" +#include "epilogue/epilogue_pipelined.h" +#include "epilogue/epilogue_rescale_output.h" +#include "gemm/custom_mma.h" +#include "gemm/find_default_mma.h" +#include "gemm/mma_from_smem.h" +#include "gemm_kernel_utils.h" +#include "transform/tile_smem_loader.h" + +using namespace gemm_kernel_utils; + +namespace { +template +constexpr int getWarpsPerSmFw() { + return ( + Arch::kMinComputeCapability >= 80 && + !cutlass::platform::is_same::value + ? 16 + : 12); +} +static CUTLASS_DEVICE float atomicMaxFloat(float* addr, float value) { + // source: https://stackoverflow.com/a/51549250 + return (value >= 0) + ? __int_as_float(atomicMax((int*)addr, __float_as_int(value))) + : __uint_as_float(atomicMin((unsigned int*)addr, __float_as_uint(value))); +} +} // namespace + +// If ToBatchHookType_ is supplied other than this default (which is +// never the case in the xformers library) then the user is +// defining the logic which each block uses to find its data to work on, +// with the advance_to_batch function with the following signature. +// It should return false if there is no work to do for this block. +// In general this will not work with saving for backward due to fixed layout +// for logsumexp and incompatible rngs for dropout, so is likely only useful for +// custom inference. 
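+// As a purely illustrative sketch (the name is hypothetical), a custom hook
+// would follow the same shape as the default below, e.g.:
+//
+//   struct MyToBatchHook {
+//     template <typename Params>
+//     CUTLASS_DEVICE static bool advance_to_batch(
+//         Params& p, int64_t& q_start, int64_t& k_start) {
+//       // Compute q_start / k_start for this block from p and blockIdx,
+//       // adjust p's pointers as needed, and return false when this block
+//       // has no work to do.
+//       return true;
+//     }
+//   };
+//
+// The default hook, with the signature described above, is: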
+struct DefaultToBatchHook { + template + CUTLASS_DEVICE static bool advance_to_batch( + Params&, + int64_t& /* q_start */, + int64_t& /* k_start */) { + return true; + } +}; + +template < + // The datatype of Q/K/V + typename scalar_t_, + // Architecture we are targeting (eg `cutlass::arch::Sm80`) + typename ArchTag, + // If Q/K/V are correctly aligned in memory and we can run a fast kernel + bool isAligned_, + int kQueriesPerBlock_, + int kKeysPerBlock_, + // upperbound on `max(value.shape[-1], query.shape[-1])` + int kMaxK_ = (int)cutlass::platform::numeric_limits::max(), + // This is quite slower on V100 for some reason + // Set to false if you know at compile-time you will never need dropout + bool kSupportsDropout_ = true, + bool kSupportsBias_ = true, + typename ToBatchHookType_ = DefaultToBatchHook> +struct AttentionKernel { + enum CustomMaskType { + NoCustomMask = 0, + CausalFromTopLeft = 1, + CausalFromBottomRight = 2, + NumCustomMaskTypes, + }; + + using scalar_t = scalar_t_; + using accum_t = float; + using lse_scalar_t = float; + using output_t = scalar_t; + // Accumulator between 2 iterations + // Using `accum_t` improves perf on f16 at the cost of + // numerical errors + using output_accum_t = accum_t; + static constexpr bool kSupportsDropout = kSupportsDropout_; + static constexpr bool kSupportsBias = kSupportsBias_; + static constexpr int kKeysPerBlock = kKeysPerBlock_; + static constexpr int kQueriesPerBlock = kQueriesPerBlock_; + static constexpr int kMaxK = kMaxK_; + static constexpr bool kIsAligned = isAligned_; + static constexpr bool kSingleValueIteration = kMaxK <= kKeysPerBlock; + static constexpr int32_t kAlignLSE = 32; // block size of backward + static constexpr bool kIsHalf = cutlass::sizeof_bits::value == 16; + static constexpr bool kPreloadV = + ArchTag::kMinComputeCapability >= 80 && kIsHalf; + static constexpr bool kKeepOutputInRF = kSingleValueIteration; + static constexpr bool kNeedsOutputAccumulatorBuffer = !kKeepOutputInRF && + !cutlass::platform::is_same::value; + + static_assert(kQueriesPerBlock % 32 == 0, ""); + static_assert(kKeysPerBlock % 32 == 0, ""); + static constexpr int kNumWarpsPerBlock = + kQueriesPerBlock * kKeysPerBlock / (32 * 32); + static constexpr int kWarpSize = 32; + + // Launch bounds + static constexpr int kNumThreads = kWarpSize * kNumWarpsPerBlock; + static constexpr int kMinBlocksPerSm = + getWarpsPerSmFw() / kNumWarpsPerBlock; + + struct Params { + // Input tensors + scalar_t* query_ptr = nullptr; // [num_queries, num_heads, head_dim] + scalar_t* key_ptr = nullptr; // [num_keys, num_heads, head_dim] + scalar_t* value_ptr = nullptr; // [num_keys, num_heads, head_dim_value] + scalar_t* attn_bias_ptr = nullptr; // [num_heads, num_queries, num_keys] + int32_t* seqstart_q_ptr = nullptr; + int32_t* seqstart_k_ptr = nullptr; + + int32_t* seqlen_k_ptr = nullptr; + uint32_t causal_diagonal_offset = 0; + + // Output tensors + output_t* output_ptr = nullptr; // [num_queries, num_heads, head_dim_value] + // [num_queries, num_heads, head_dim_value] + output_accum_t* output_accum_ptr = nullptr; + // [num_heads, num_queries] - can be null + lse_scalar_t* logsumexp_ptr = nullptr; + + // Scale + accum_t scale = 0.0; + + // Dimensions/strides + int32_t head_dim = 0; + int32_t head_dim_value = 0; + int32_t num_queries = 0; + int32_t num_keys = 0; + int32_t num_keys_absolute = 0; + + uint8_t custom_mask_type = NoCustomMask; + + int32_t q_strideM = 0; + int32_t k_strideM = 0; + int32_t v_strideM = 0; + int32_t bias_strideM = 0; + + int32_t 
o_strideM = 0; + + // Everything below is only used in `advance_to_block` + // and shouldn't use registers + int32_t q_strideH = 0; + int32_t k_strideH = 0; + int32_t v_strideH = 0; + int64_t bias_strideH = 0; + + int64_t q_strideB = 0; + int64_t k_strideB = 0; + int64_t v_strideB = 0; + int64_t bias_strideB = 0; + + int32_t num_batches = 0; + int32_t num_heads = 0; + + // dropout + bool use_dropout = false; + unsigned long long dropout_batch_head_rng_offset = 0; + float dropout_prob = 0.0f; +#ifdef HAS_PYTORCH + at::PhiloxCudaState rng_engine_inputs = at::PhiloxCudaState(0, 0); +#endif + + // Moves pointers to what we should process + // Returns "false" if there is no work to do + CUTLASS_DEVICE bool advance_to_block() { + auto batch_id = blockIdx.z; + auto head_id = blockIdx.y; + auto query_start = blockIdx.x * kQueriesPerBlock; + + auto lse_dim = ceil_div((int32_t)num_queries, kAlignLSE) * kAlignLSE; + + if (kSupportsDropout) { + dropout_batch_head_rng_offset = + batch_id * num_heads * num_queries * num_keys + + head_id * num_queries * num_keys; + } + + int64_t q_start = 0, k_start = 0; + // Advance to current batch - in case of different sequence lengths + constexpr bool kToBatchHook = + !cutlass::platform::is_same:: + value; + if (kToBatchHook) { + // Call out to a custom implementation. + if (!ToBatchHookType_::advance_to_batch(*this, q_start, k_start)) { + return false; + } + } else if (seqstart_q_ptr != nullptr) { + assert(seqstart_k_ptr != nullptr); + seqstart_q_ptr += batch_id; + + q_start = seqstart_q_ptr[0]; + int64_t q_next_start = seqstart_q_ptr[1]; + int64_t k_end; + seqstart_k_ptr += batch_id; + + if (seqlen_k_ptr) { + k_start = seqstart_k_ptr[0]; + k_end = k_start + seqlen_k_ptr[batch_id]; + } else { + k_start = seqstart_k_ptr[0]; + k_end = seqstart_k_ptr[1]; + } + + num_queries = q_next_start - q_start; + num_keys = k_end - k_start; + + if (query_start >= num_queries) { + return false; + } + } else { + query_ptr += batch_id * q_strideB; + key_ptr += batch_id * k_strideB; + value_ptr += batch_id * v_strideB; + output_ptr += int64_t(batch_id * num_queries) * o_strideM; + if (output_accum_ptr != nullptr) { + output_accum_ptr += + int64_t(batch_id * num_queries) * (head_dim_value * num_heads); + } + q_start = 0; + k_start = 0; + } + + // Advance to the current batch / head / query_start + query_ptr += (q_start + query_start) * q_strideM + head_id * q_strideH; + key_ptr += k_start * k_strideM + head_id * k_strideH; + + value_ptr += k_start * v_strideM + head_id * v_strideH; + output_ptr += + int64_t(q_start + query_start) * o_strideM + head_id * head_dim_value; + + if (kSupportsBias && attn_bias_ptr != nullptr) { + attn_bias_ptr += (batch_id * bias_strideB) + (head_id * bias_strideH); + } + if (output_accum_ptr != nullptr) { + output_accum_ptr += + int64_t(q_start + query_start) * (head_dim_value * num_heads) + + head_id * head_dim_value; + } else { + // Accumulate directly in the destination buffer (eg for f32) + output_accum_ptr = (accum_t*)output_ptr; + } + + if (logsumexp_ptr != nullptr) { + // lse[batch_id, head_id, query_start] + logsumexp_ptr += + batch_id * lse_dim * num_heads + head_id * lse_dim + query_start; + } + + // Custom masking + if (custom_mask_type == CausalFromBottomRight) { + causal_diagonal_offset = num_keys - num_queries; + } + // We use num_keys_absolute to index into the rng_state + // We need this index to match between forward and backwards + num_keys_absolute = num_keys; + if (custom_mask_type == CausalFromTopLeft || + custom_mask_type == 
CausalFromBottomRight) { + // the bottom row of the current block is query_start + kQueriesPerBlock + // the last active key is then query_start + causal_diagonal_offset + + // kQueriesPerBlock so num_keys is the min between actual num_keys and + // this to avoid extra computations + num_keys = cutlass::fast_min( + int32_t(query_start + causal_diagonal_offset + kQueriesPerBlock), + num_keys); + } + + num_queries -= query_start; + num_batches = 0; // no longer used after + + // If num_queries == 1, and there is only one key head we're wasting + // 15/16th of tensor core compute In that case : + // - we only launch kernels for head_id % kQueriesPerBlock == 0 + // - we iterate over heads instead of queries (strideM = strideH) + if (num_queries == 1 && k_strideH == 0 && v_strideH == 0) { + if (head_id % kQueriesPerBlock != 0) + return false; + q_strideM = q_strideH; + num_queries = num_heads; + num_heads = 1; // unused but here for intent + // remove causal since n_query = 1 + // otherwise, offset would change with head ! + custom_mask_type = NoCustomMask; + o_strideM = head_dim_value; + } + + // Make sure the compiler knows these variables are the same on all + // the threads of the warp. + // Only worth doing if they could have been modified above. + query_ptr = warp_uniform(query_ptr); + key_ptr = warp_uniform(key_ptr); + value_ptr = warp_uniform(value_ptr); + if (kSupportsBias) { + attn_bias_ptr = warp_uniform(attn_bias_ptr); + } + output_ptr = warp_uniform(output_ptr); + output_accum_ptr = warp_uniform(output_accum_ptr); + logsumexp_ptr = warp_uniform(logsumexp_ptr); + num_queries = warp_uniform(num_queries); + num_keys = warp_uniform(num_keys); + num_heads = warp_uniform(num_heads); + o_strideM = warp_uniform(o_strideM); + custom_mask_type = warp_uniform(custom_mask_type); + return true; + } + + __host__ dim3 getBlocksGrid() const { + return dim3( + ceil_div(num_queries, (int32_t)kQueriesPerBlock), + num_heads, + num_batches); + } + + __host__ dim3 getThreadsGrid() const { + return dim3(kWarpSize, kNumWarpsPerBlock, 1); + } + }; + + struct MM0 { + /* + In this first matmul, we compute a block of `Q @ K.T`. + While the calculation result is still hot in registers, we update + `mi`, `m_prime`, `s_prime` in shared-memory, and then store this value + into a shared-memory ("AccumulatorSharedStorage") that is used later as + operand A for the second matmul (see MM1) + */ + using GemmType = DefaultGemmType; + + using OpClass = typename GemmType::OpClass; + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + OpClass, + ArchTag, + scalar_t, + scalar_t, + scalar_t, // ElementC + accum_t // ElementAccumulator + >; + static constexpr int kAlignmentA = + kIsAligned ? DefaultConfig::kAlignmentA : GemmType::kMinimumAlignment; + static constexpr int kAlignmentB = + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; + using ThreadblockShape = cutlass::gemm:: + GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using DefaultMma = typename cutlass::gemm::threadblock::FindDefaultMma< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + kAlignmentA, + scalar_t, // ElementB, + cutlass::layout::ColumnMajor, // LayoutB, + kAlignmentB, + accum_t, + cutlass::layout::RowMajor, // LayoutC, + OpClass, + ArchTag, // ArchTag + ThreadblockShape, // ThreadblockShape + WarpShape, // WarpShape + typename GemmType::InstructionShape, // InstructionShape + ArchTag::kMinComputeCapability >= 80 && kIsHalf + ? 
4 + : DefaultConfig::kStages, + typename GemmType::Operator // Operator + >::DefaultMma; + using MmaCore = typename DefaultMma::MmaCore; + using IteratorA = typename DefaultMma::IteratorA; + using IteratorB = typename DefaultMma::IteratorB; + using DefaultThreadblockMma = typename DefaultMma::ThreadblockMma; + using Mma = typename cutlass::platform::conditional< + kSingleValueIteration, + typename MakeCustomMma::Mma, + DefaultThreadblockMma>::type; + using AccumLambdaIterator = typename DefaultMmaAccumLambdaIterator< + typename Mma::Operator::IteratorC, + accum_t, + kWarpSize>::Iterator; + static_assert( + MmaCore::WarpCount::kM * MmaCore::WarpCount::kN * + MmaCore::WarpCount::kK == + kNumWarpsPerBlock, + ""); + + // used for efficient load of bias tile Bij from global to shared memory + using BiasLoader = TileSmemLoader< + scalar_t, + cutlass::MatrixShape, + MmaCore::kThreads, + // input restriction: kv_len has to be a multiple of this value + 128 / cutlass::sizeof_bits::value>; + + // Epilogue to store to shared-memory in a format that we can use later for + // the second matmul + using B2bGemm = typename cutlass::gemm::threadblock::B2bGemm< + typename Mma::Operator::IteratorC, + typename Mma::Operator, + scalar_t, + WarpShape, + ThreadblockShape>; + using AccumulatorSharedStorage = typename B2bGemm::AccumulatorSharedStorage; + }; + + struct MM1 { + /** + Second matmul: perform `attn @ V` where `attn` is the attention (not + normalized) and stored in shared memory + */ + using GemmType = DefaultGemmType; + + using OpClass = typename GemmType::OpClass; + using DefaultConfig = + typename cutlass::gemm::device::DefaultGemmConfiguration< + OpClass, + ArchTag, + scalar_t, + scalar_t, + output_accum_t, // ElementC + accum_t // ElementAccumulator + >; + static constexpr int kAlignmentA = DefaultConfig::kAlignmentA; // from smem + static constexpr int kAlignmentB = + kIsAligned ? DefaultConfig::kAlignmentB : GemmType::kMinimumAlignment; + using ThreadblockShape = cutlass::gemm:: + GemmShape; + using WarpShape = cutlass::gemm::GemmShape<32, 32, GemmType::WarpK>; + using InstructionShape = typename GemmType::InstructionShape; + + using LayoutB = cutlass::layout::RowMajor; + using DefaultGemm = cutlass::gemm::kernel::DefaultGemm< + scalar_t, // ElementA, + cutlass::layout::RowMajor, // LayoutA, + kAlignmentA, + scalar_t, // ElementB, + LayoutB, // LayoutB, + kAlignmentB, + output_accum_t, + cutlass::layout::RowMajor, // LayoutC, + accum_t, + OpClass, + ArchTag, + ThreadblockShape, + WarpShape, + typename GemmType::InstructionShape, + typename DefaultConfig::EpilogueOutputOp, + void, // ThreadblockSwizzle - not used + ArchTag::kMinComputeCapability >= 80 && kIsHalf + ? 
4 + : DefaultConfig::kStages, + false, // SplitKSerial + typename GemmType::Operator>; + + using WarpIteratorA = typename cutlass::gemm::threadblock:: + DefaultWarpIteratorAFromSharedMemory< + typename DefaultGemm::Mma::Policy::Operator::Shape, // WarpShape + typename DefaultGemm::Mma::Policy::Operator::InstructionShape, + typename DefaultGemm::Mma::Policy::Operator::IteratorA, + typename DefaultGemm::Mma::Policy>::WarpIterator; + using DefaultMmaFromSmem = + typename cutlass::gemm::threadblock::DefaultMmaFromSharedMemory< + typename DefaultGemm::Mma, + MM0::AccumulatorSharedStorage::Shape::kN, // kMaxK + WarpIteratorA, + false>; // kScaleOperandA + using Mma = typename DefaultMmaFromSmem::Mma; + using IteratorB = typename Mma::IteratorB; + using WarpCount = typename Mma::WarpCount; + static_assert( + WarpCount::kM * WarpCount::kN * WarpCount::kK == kNumWarpsPerBlock, + ""); + + using DefaultEpilogue = typename DefaultGemm::Epilogue; + using OutputTileIterator = + typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_t>; + using OutputTileIteratorAccum = + typename cutlass::epilogue::threadblock::PredicatedTileIterator< + typename DefaultEpilogue::OutputTileIterator::ThreadMap, + output_accum_t>; + }; + + static constexpr int64_t kAlignmentQ = MM0::kAlignmentA; + static constexpr int64_t kAlignmentK = MM0::kAlignmentB; + static constexpr int64_t kAlignmentV = 1; + + // Shared storage - depends on kernel params + struct ScalingCoefs { + cutlass::Array m_prime; + cutlass::Array s_prime; + cutlass::Array mi; + cutlass::Array out_rescale; + cutlass::Array + addition_storage; + }; + + struct SharedStorageEpilogueAtEnd : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + union { + typename MM0::BiasLoader::SmemTile bias; + typename MM0::AccumulatorSharedStorage si; + }; + typename MM1::Mma::SharedStorage mm1; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& + epilogue_shared_storage() { + return epilogue; + } + }; + + struct SharedStorageEpilogueInLoop : ScalingCoefs { + struct SharedStorageAfterMM0 { + // Everything here might be overwritten during MM0 + union { + typename MM0::BiasLoader::SmemTile bias; + typename MM0::AccumulatorSharedStorage si; + }; + typename MM1::Mma::SharedStorage mm1; + typename MM1::DefaultEpilogue::SharedStorage epilogue; + }; + + union { + typename MM0::Mma::SharedStorage mm0; + SharedStorageAfterMM0 after_mm0; + }; + + CUTLASS_DEVICE typename MM1::DefaultEpilogue::SharedStorage& + epilogue_shared_storage() { + return after_mm0.epilogue; + } + }; + + using SharedStorage = typename cutlass::platform::conditional< + kSingleValueIteration || kKeepOutputInRF, + SharedStorageEpilogueAtEnd, + SharedStorageEpilogueInLoop>::type; + + static bool __host__ check_supported(Params const& p) { + CHECK_ALIGNED_PTR(p.query_ptr, kAlignmentQ); + CHECK_ALIGNED_PTR(p.key_ptr, kAlignmentK); + CHECK_ALIGNED_PTR(p.value_ptr, kAlignmentV); + if (kSupportsBias) { + CHECK_ALIGNED_PTR(p.attn_bias_ptr, kAlignmentQ); + XFORMERS_CHECK( + p.num_batches <= 1 || p.bias_strideB % kAlignmentQ == 0, + "attn_bias is not correctly aligned (strideB)"); + XFORMERS_CHECK( + p.num_heads <= 1 || p.bias_strideH % kAlignmentQ == 0, + "attn_bias is not correctly aligned (strideH)"); + XFORMERS_CHECK( + 
p.bias_strideM % kAlignmentQ == 0, + "attn_bias is not correctly aligned"); + } + XFORMERS_CHECK( + p.q_strideM % kAlignmentQ == 0, + "query is not correctly aligned (strideM)"); + XFORMERS_CHECK( + p.k_strideM % kAlignmentK == 0, + "key is not correctly aligned (strideM)"); + XFORMERS_CHECK( + p.v_strideM % kAlignmentV == 0, + "value is not correctly aligned (strideM)"); + XFORMERS_CHECK( + p.num_heads <= 1 || p.q_strideH % kAlignmentQ == 0, + "query is not correctly aligned (strideH)"); + XFORMERS_CHECK( + p.num_heads <= 1 || p.k_strideH % kAlignmentK == 0, + "key is not correctly aligned (strideH)"); + XFORMERS_CHECK( + p.num_heads <= 1 || p.v_strideH % kAlignmentV == 0, + "value is not correctly aligned (strideH)"); + XFORMERS_CHECK( + p.custom_mask_type < NumCustomMaskTypes, + "invalid value for `custom_mask_type`"); + return true; + } + + static void CUTLASS_DEVICE attention_kernel(Params& p) { + // In this block, we will only ever: + // - read query[query_start:query_end, :] + // - write to output[query_start:query_end, :] + + extern __shared__ char smem_buffer[]; + SharedStorage& shared_storage = *((SharedStorage*)smem_buffer); + auto& m_prime = shared_storage.m_prime; + auto& s_prime = shared_storage.s_prime; + auto& mi = shared_storage.mi; + auto& out_rescale = shared_storage.out_rescale; + const uint32_t query_start = blockIdx.x * kQueriesPerBlock; + + static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); + if (thread_id() < kQueriesPerBlock) { + s_prime[thread_id()] = accum_t(0); + out_rescale[thread_id()] = accum_t(1.0); + m_prime[thread_id()] = + -cutlass::platform::numeric_limits::infinity(); + mi[thread_id()] = -cutlass::platform::numeric_limits::infinity(); + } + typename MM1::Mma::FragmentC accum_o; + accum_o.clear(); + + auto createOutputIter = [&](int col) -> typename MM1::OutputTileIterator { + using OutputTileIterator = typename MM1::OutputTileIterator; + return OutputTileIterator( + typename OutputTileIterator::Params{(int32_t)p.o_strideM}, + p.output_ptr, + typename OutputTileIterator::TensorCoord{ + p.num_queries, p.head_dim_value}, + thread_id(), + {0, col}); + }; + + auto createOutputAccumIter = [&](int col) -> + typename MM1::OutputTileIteratorAccum { + using OutputTileIteratorAccum = typename MM1::OutputTileIteratorAccum; + return OutputTileIteratorAccum( + typename OutputTileIteratorAccum::Params{ + (int32_t)(p.head_dim_value * p.num_heads)}, + p.output_accum_ptr, + typename OutputTileIteratorAccum::TensorCoord{ + p.num_queries, p.head_dim_value}, + thread_id(), + {0, col}); + }; + +#ifdef HAS_PYTORCH + curandStatePhilox4_32_10_t curand_state_init; + if (kSupportsDropout && p.use_dropout) { + const auto seeds = at::cuda::philox::unpack(p.rng_engine_inputs); + + // each element of the attention matrix P with shape + // (batch_sz, n_heads, n_queries, n_keys) is associated with a single + // offset in RNG sequence. we initialize the RNG state with offset that + // starts at the beginning of a (n_queries, n_keys) matrix for this + // block's batch_id and head_id + // initializing rng state is very expensive, so we run once per kernel, + // rather than once per iteration. each iteration takes a copy of the + // initialized RNG state and offsets it as needed. 
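+      // As a concrete, purely illustrative example: with n_heads = 8,
+      // n_queries = 512 and n_keys = 512, the block for batch_id = 2 and
+      // head_id = 3 gets
+      //   dropout_batch_head_rng_offset = 2*8*512*512 + 3*512*512 = 4,980,736
+      // and element (q, k) of this block's attention matrix then maps to
+      //   dropout_batch_head_rng_offset + q * n_keys + k
+      // in the Philox sequence, matching the skipahead() call further down.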
+ curand_init( + std::get<0>(seeds), + 0, + std::get<1>(seeds) + p.dropout_batch_head_rng_offset, + &curand_state_init); + } +#endif + + // Iterate through keys + for (int32_t iter_key_start = 0; iter_key_start < p.num_keys; + iter_key_start += kKeysPerBlock) { + int32_t problem_size_0_m = + cutlass::fast_min((int32_t)kQueriesPerBlock, p.num_queries); + int32_t problem_size_0_n = cutlass::fast_min( + int32_t(kKeysPerBlock), p.num_keys - iter_key_start); + int32_t const& problem_size_0_k = p.head_dim; + int32_t const& problem_size_1_n = p.head_dim_value; + int32_t const& problem_size_1_k = problem_size_0_n; + + auto prologueV = [&](int blockN) { + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{typename MM1::LayoutB(p.v_strideM)}, + p.value_ptr + iter_key_start * p.v_strideM, + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + MM1::Mma::prologue( + shared_storage.after_mm0.mm1, + iterator_V, + thread_id(), + problem_size_1_k); + }; + + __syncthreads(); // Need to have shared memory initialized, and `m_prime` + // updated from end of prev iter + // + // MATMUL: Q.K_t + // + // Computes the block-matrix product of: + // (a) query[query_start:query_end, :] + // with + // (b) key[iter_key_start:iter_key_start + kKeysPerBlock] + // and stores that into `shared_storage.si` + // + + // Compute threadblock location + cutlass::gemm::GemmCoord tb_tile_offset = {0, 0, 0}; + + cutlass::MatrixCoord tb_offset_A{ + tb_tile_offset.m() * MM0::Mma::Shape::kM, tb_tile_offset.k()}; + + cutlass::MatrixCoord tb_offset_B{ + tb_tile_offset.k(), tb_tile_offset.n() * MM0::Mma::Shape::kN}; + + // Construct iterators to A and B operands + typename MM0::IteratorA iterator_A( + typename MM0::IteratorA::Params( + typename MM0::MmaCore::LayoutA(p.q_strideM)), + p.query_ptr, + {problem_size_0_m, problem_size_0_k}, + thread_id(), + tb_offset_A); + + typename MM0::IteratorB iterator_B( + typename MM0::IteratorB::Params( + typename MM0::MmaCore::LayoutB(p.k_strideM)), + p.key_ptr + iter_key_start * p.k_strideM, + {problem_size_0_k, problem_size_0_n}, + thread_id(), + tb_offset_B); + + auto my_warp_id = warp_uniform(warp_id()); + auto my_lane_id = lane_id(); + + // Construct thread-scoped matrix multiply + typename MM0::Mma mma( + shared_storage.mm0, thread_id(), my_warp_id, my_lane_id); + + typename MM0::Mma::FragmentC accum; + + accum.clear(); + + auto gemm_k_iterations = + (problem_size_0_k + MM0::Mma::Shape::kK - 1) / MM0::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, accum, iterator_A, iterator_B, accum); + __syncthreads(); + + if (kPreloadV) { + prologueV(0); + } else { + MM1::Mma::drain_cp_asyncs(); + } + + typename MM0::Mma::Operator::IteratorC::TensorCoord + iteratorC_tile_offset = { + (tb_tile_offset.m() * MM0::Mma::WarpCount::kM) + + (my_warp_id % MM0::Mma::WarpCount::kM), + (tb_tile_offset.n() * MM0::Mma::WarpCount::kN) + + (my_warp_id / MM0::Mma::WarpCount::kM)}; + + // multiply by scaling factor + if (kSupportsBias) { + accum = + cutlass::multiplies()(p.scale, accum); + } + + // apply attention bias if applicable + if (kSupportsBias && p.attn_bias_ptr != nullptr) { + // load bias tile Bij into shared memory + typename MM0::BiasLoader::GmemTileIterator bias_iter( + {cutlass::layout::RowMajor(p.bias_strideM)}, + // attn_bias_pointer points to matrix of size (n_queries, n_keys) + // for the relevant batch_id and head_id + p.attn_bias_ptr + query_start * p.bias_strideM + 
iter_key_start, + {problem_size_0_m, problem_size_0_n}, + thread_id()); + cutlass::TensorRef bias_tensor_ref( + shared_storage.after_mm0.bias.data(), + cutlass::layout::RowMajor(MM0::ThreadblockShape::kN)); + typename MM0::BiasLoader::SmemTileIterator smem_tile_iter( + bias_tensor_ref, thread_id()); + MM0::BiasLoader::load(bias_iter, smem_tile_iter); + + // Pij += Bij, Pij is in register fragment and Bij is in shared memory + auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset( + my_lane_id, my_warp_id, iteratorC_tile_offset); + MM0::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) {}, + [&](int accum_m, int accum_n, int idx) { + if (accum_m < problem_size_0_m && accum_n < problem_size_0_n) { + accum[idx] += bias_tensor_ref.at({accum_m, accum_n}); + } + }, + [&](int accum_m) {}); + } + + // Mask out last if causal + // This is only needed if upper-right corner of current query / key block + // intersects the mask Coordinates of upper-right corner of current block + // is y=query_start x=min(iter_key_start + kKeysPerBlock, num_keys)) The + // first masked element is x = y + offset -> query_start + offset There is + // intersection (and we need to mask) if min(iter_key_start + + // kKeysPerBlock, num_keys)) >= query_start + offset + if (p.custom_mask_type && + cutlass::fast_min(iter_key_start + kKeysPerBlock, p.num_keys) >= + (query_start + p.causal_diagonal_offset)) { + auto query_start = blockIdx.x * kQueriesPerBlock; + auto lane_offset = MM0::AccumLambdaIterator::get_lane_offset( + my_lane_id, my_warp_id, iteratorC_tile_offset); + int32_t last_col; + MM0::AccumLambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + // last absolute col is (last absolute query + offset) + // last local col is (last absolute query + offset - + // iter_key_start) + last_col = query_start + accum_m + p.causal_diagonal_offset - + iter_key_start; + }, + [&](int accum_m, int accum_n, int idx) { + if (accum_n > last_col) { + accum[idx] = + -cutlass::platform::numeric_limits::infinity(); + } + }, + [&](int accum_m) {}); + } + // Update `mi` from accum stored in registers + // Also does accum[i] <- exp(accum[i] - mi) + iterative_softmax( + accum_o, + accum, + mi, + m_prime, + s_prime, + out_rescale, + shared_storage.addition_storage, + my_lane_id, + thread_id(), + my_warp_id, + p.num_keys - iter_key_start, + iter_key_start == 0, + iteratorC_tile_offset, + kSupportsBias ? 1.0f : p.scale); + + // Output results to shared-memory + int warp_idx_mn_0 = my_warp_id % + (MM0::Mma::Base::WarpCount::kM * MM0::Mma::Base::WarpCount::kN); + auto output_tile_coords = cutlass::MatrixCoord{ + warp_idx_mn_0 % MM0::Mma::Base::WarpCount::kM, + warp_idx_mn_0 / MM0::Mma::Base::WarpCount::kM}; + + MM0::B2bGemm::accumToSmem( + shared_storage.after_mm0.si, accum, my_lane_id, output_tile_coords); + + __syncthreads(); + +#ifdef HAS_PYTORCH + // apply dropout (if applicable) after we've written Pij to smem. + // dropout is applied by multiplying each element of Pij by: + // - 0 with probability dropout_p + // - 1 / (1 - dropout_p) with probability 1 - dropout_p + // + // for backward purposes we want to be able to map each element of the + // attention matrix to the same random uniform number as the one we used + // in forward, without needing to use the same iteration order or having + // to store the dropout matrix. 
its possible to do this in registers but + // it ends up being very slow because each thread having noncontiguous + // strips of the Pij tile means we have to skip around a lot, and also + // have to generate a single random number at a time + if (kSupportsDropout && p.use_dropout) { + auto si = shared_storage.after_mm0.si.accum_ref(); + // each thread handles a contiguous sequence of elements from Sij, all + // coming from the same row. the reason they have to come from the same + // row is that the sampling random numbers from a contiguous random + // number sequence is much more efficient than jumping around, and the + // linear offset of each element of S (the global matrix) maps to an + // offset in a random number sequence. for S, the end of a row and the + // beginning of the next have adjacent offsets, but for Sij, this is not + // necessarily the case. + const int num_threads = blockDim.x * blockDim.y * blockDim.z; + const int threads_per_row = + cutlass::fast_min(num_threads / problem_size_0_m, problem_size_0_n); + const int elts_per_thread = cutlass::round_nearest( + cutlass::ceil_div(problem_size_0_n, threads_per_row), 4); + + const int thread_i = thread_id() / threads_per_row; + const int thread_start_j = + (thread_id() % threads_per_row) * elts_per_thread; + + if (thread_i < problem_size_0_m && thread_start_j < problem_size_0_n) { + curandStatePhilox4_32_10_t curand_state = curand_state_init; + skipahead( + static_cast( + (query_start + thread_i) * p.num_keys_absolute + + (iter_key_start + thread_start_j)), + &curand_state); + const float dropout_scale = 1.0 / (1.0 - p.dropout_prob); + + // apply dropout scaling to elements this thread is responsible for, + // in chunks of 4 + for (int sij_start_col_idx = thread_start_j; sij_start_col_idx < + cutlass::fast_min(thread_start_j + elts_per_thread, + problem_size_0_n); + sij_start_col_idx += 4) { + const float4 rand_uniform_quad = curand_uniform4(&curand_state); + + CUTLASS_PRAGMA_UNROLL + for (int quad_idx = 0; quad_idx < 4; ++quad_idx) { + si.at({thread_i, sij_start_col_idx + quad_idx}) *= + static_cast( + dropout_scale * + ((&rand_uniform_quad.x)[quad_idx] > p.dropout_prob)); + } + } + } + __syncthreads(); // p.use_dropout should have same value kernel-wide + } +#endif + + // + // MATMUL: Attn . V + // Run the matmul `attn @ V` for a block of attn and V. + // `attn` is read from shared memory (in `shared_storage_si`) + // `V` is read from global memory (with iterator_B) + // + + const int64_t nBlockN = kSingleValueIteration + ? 
1 + : ceil_div( + (int64_t)problem_size_1_n, int64_t(MM1::ThreadblockShape::kN)); + for (int blockN = 0; blockN < nBlockN; ++blockN) { + int gemm_k_iterations = + (problem_size_1_k + MM1::Mma::Shape::kK - 1) / MM1::Mma::Shape::kK; + + // Compute threadblock-scoped matrix multiply-add and store it in accum + // (in registers) + if (!kPreloadV) { + __syncthreads(); // we share shmem between mma and epilogue + } + + typename MM1::Mma::IteratorB iterator_V( + typename MM1::IteratorB::Params{typename MM1::LayoutB(p.v_strideM)}, + p.value_ptr + iter_key_start * p.v_strideM, + {problem_size_1_k, problem_size_1_n}, + thread_id(), + cutlass::MatrixCoord{0, blockN * MM1::Mma::Shape::kN}); + typename MM1::Mma mma_pv( + // operand A: Pij_dropped in shared memory + shared_storage.after_mm0.si.accum_ref(), + // operand B: shared memory staging area for Vj, which is loaded + // from global memory + shared_storage.after_mm0.mm1.operand_B_ref(), + (int)thread_id(), + (int)my_warp_id, + (int)my_lane_id); + mma_pv.set_prologue_done(kPreloadV); + if (!kKeepOutputInRF) { + accum_o.clear(); + } + mma_pv(gemm_k_iterations, accum_o, iterator_V, accum_o); + __syncthreads(); + + if (kPreloadV && !kSingleValueIteration && blockN + 1 < nBlockN) { + prologueV(blockN + 1); + } + + if (!kKeepOutputInRF) { + MM1::Mma::drain_cp_asyncs(); + DISPATCH_BOOL( + iter_key_start == 0, kIsFirst, ([&] { + DISPATCH_BOOL( + (iter_key_start + kKeysPerBlock) >= p.num_keys, + kIsLast, + ([&] { + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = + typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = typename DefaultOp::ElementCompute; + using EpilogueOutputOp = typename cutlass::epilogue:: + thread::MemoryEfficientAttentionNormalize< + typename cutlass::platform::conditional< + kIsLast::value, + output_t, + output_accum_t>::type, + output_accum_t, + DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, + ElementCompute, + kIsFirst::value, + kIsLast::value, + cutlass::Array>; + using Epilogue = typename cutlass::epilogue::threadblock:: + EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename cutlass::platform::conditional< + kIsLast::value, + typename MM1::OutputTileIterator, + typename MM1::OutputTileIteratorAccum>::type, + typename DefaultEpilogue:: + AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1::OutputTileIteratorAccum // Read + // iterator + >; + + int col = blockN * MM1::Mma::Shape::kN; + auto source_iter = createOutputAccumIter(col); + auto dest_iter = call_conditional< + kIsLast::value, + decltype(createOutputIter), + decltype(createOutputAccumIter)>:: + apply(createOutputIter, createOutputAccumIter, col); + EpilogueOutputOp rescale(s_prime, out_rescale); + Epilogue epilogue( + shared_storage.epilogue_shared_storage(), + thread_id(), + my_warp_id, + my_lane_id); + epilogue(rescale, dest_iter, accum_o, source_iter); + })); + })); + if (!kSingleValueIteration) { + __syncthreads(); + } + } + } + __syncthreads(); // we modify `m_prime` after + } + + if (kKeepOutputInRF) { + constexpr bool kIsFirst = true; + constexpr bool kIsLast = true; + using DefaultEpilogue = typename MM1::DefaultEpilogue; + using DefaultOp = typename MM1::DefaultConfig::EpilogueOutputOp; + using ElementCompute = typename 
DefaultOp::ElementCompute; + using EpilogueOutputOp = + typename cutlass::epilogue::thread::MemoryEfficientAttentionNormalize< + output_t, // output + output_accum_t, // source + DefaultOp::kCount, + typename DefaultOp::ElementAccumulator, // accum + output_accum_t, // compute + kIsFirst, + kIsLast, + cutlass::Array>; + using Epilogue = + typename cutlass::epilogue::threadblock::EpiloguePipelined< + typename DefaultEpilogue::Shape, + typename MM1::Mma::Operator, + DefaultEpilogue::kPartitionsK, + typename MM1::OutputTileIterator, // destination + typename DefaultEpilogue::AccumulatorFragmentIterator, + typename DefaultEpilogue::WarpTileIterator, + typename DefaultEpilogue::SharedLoadIterator, + EpilogueOutputOp, + typename DefaultEpilogue::Padding, + DefaultEpilogue::kFragmentsPerIteration, + true, // IterationsUnroll + typename MM1::OutputTileIteratorAccum // source tile + >; + auto dest_iter = createOutputIter(0); + EpilogueOutputOp rescale(s_prime, out_rescale); + Epilogue epilogue( + shared_storage.epilogue_shared_storage(), + thread_id(), + warp_id(), + lane_id()); + MM1::Mma::drain_cp_asyncs(); + epilogue(rescale, dest_iter, accum_o); + } + + // 7. Calculate logsumexp + // To make the backward easier, we pad logsumexp with `inf` + // this avoids a few bound checks, and is not more expensive during fwd + static_assert(kQueriesPerBlock < kNumWarpsPerBlock * kWarpSize, ""); + if (p.logsumexp_ptr && thread_id() < kQueriesPerBlock) { + auto lse_dim = ceil_div((int32_t)p.num_queries, kAlignLSE) * kAlignLSE; + constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E + if (thread_id() < p.num_queries) { + p.logsumexp_ptr[thread_id()] = accum_t(mi[thread_id()] / kLog2e) + + cutlass::fast_log(accum_t(s_prime[thread_id()])); + } else if (thread_id() < lse_dim) { + p.logsumexp_ptr[thread_id()] = + cutlass::platform::numeric_limits::infinity(); + } + } + } + + template + CUTLASS_DEVICE static void iterative_softmax( + typename WarpIteratorC::Fragment& frag_o, // output so far + typename WarpIteratorC::Fragment& frag, + cutlass::Array& mi, + cutlass::Array& m_prime, + cutlass::Array& s_prime, + cutlass::Array& out_rescale, + cutlass::Array& + addition_storage, + int8_t lane_id, + int8_t thread_id, + int8_t warp_id, + int max_col, + bool is_first, + typename WarpIteratorC::TensorCoord const& tile_offset, + float scaling) { + /* Iterates on the accumulator and corresponding position on result matrix + + (1) Update `mi[r]` to the max value of the row `r` + (2) In a second iteration do the following: + (a) accum <- exp(accum - mi) + (b) m_prime <- exp(m_prime - mi) + (c) s_prime <- s_prime * m_prime + sum(accum) + + All of this is done on registers, before we store all of this + on shared memory for the next matmul with Value. 
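+
+    The rescaling in (2b)/(2c) relies on the identity
+      exp(x - m_new) = exp(x - m_old) * exp(m_old - m_new),
+    so sums (and, when kKeepOutputInRF, the partial output via `out_rescale`)
+    accumulated under an older row max only need to be multiplied by
+    exp(m_old - m_new) when a larger max is found; nothing already computed has
+    to be re-exponentiated. The code below works in base 2 (exp2f after
+    pre-scaling by log2(e)), which is equivalent and cheaper on GPU.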
+ */ + using Fragment = typename WarpIteratorC::Fragment; + using LambdaIterator = typename DefaultMmaAccumLambdaIterator< + WarpIteratorC, + accum_t, + kWarpSize>::Iterator; + // Convert to `accum_t` (rather than double) + constexpr float kLog2e = 1.4426950408889634074; // log_2(e) = M_LOG2E + + static_assert(kQueriesPerBlock % kNumWarpsPerBlock == 0, ""); + static constexpr int kLinesPerWarp = kQueriesPerBlock / kNumWarpsPerBlock; + + frag = cutlass::multiplies()(scaling * kLog2e, frag); + + auto lane_offset = + LambdaIterator::get_lane_offset(lane_id, warp_id, tile_offset); + + // First update `mi` to the max per-row + { + accum_t max; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { + max = -cutlass::platform::numeric_limits::infinity(); + }, + [&](int accum_m, int accum_n, int idx) { + if (accum_n < max_col) { + max = cutlass::fast_max(max, frag[idx]); + } + }, + [&](int accum_m) { + // Having 4x atomicMax seems faster than reduce within warp + // first... + atomicMaxFloat(&mi[accum_m], max); + }); + } + + // Make sure we all share the update values for `mi` + __syncthreads(); + + // Doing this `exp` is quite expensive. Let's + // split it across the warps + bool restore_mi_to_minus_inf = false; + if (lane_id < kLinesPerWarp) { + int id = warp_id * kLinesPerWarp + lane_id; + auto m_prime_id = m_prime[id]; + auto mi_id = mi[id]; + bool changed = m_prime_id < mi_id; // `false` if both are -inf + if (changed) { + auto m_prime_exp = exp2f(m_prime_id - mi_id); + out_rescale[id] = m_prime_exp; + s_prime[id] *= m_prime_exp; + } else { + // Only when bias is enabled, it's possible that all the first values + // of attention are masked to `-inf`. In that case we want to avoid + // `nan = exp2f(-inf - (-inf))` so we temporarily set `mi` to 0 + if (kSupportsBias && + mi_id == -cutlass::platform::numeric_limits::infinity()) { + restore_mi_to_minus_inf = true; + mi[id] = 0.0f; + } + out_rescale[id] = 1.0f; + } + } + __syncthreads(); // Update output fragments + if (kKeepOutputInRF && !is_first) { + accum_t line_rescale; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { line_rescale = out_rescale[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + frag_o[idx] = frag_o[idx] * line_rescale; + }, + [&](int accum_m) {}); + } + // Update accum_m, accum_n, ... + { + accum_t mi_row, total_row; + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { mi_row = mi[accum_m]; }, + [&](int accum_m, int accum_n, int idx) { + frag[idx] = + (accum_n < max_col) ? 
exp2f(frag[idx] - mi_row) : accum_t(0.0); + }, + [&](int accum_m) {}); + LambdaIterator::iterateRows( + lane_offset, + [&](int accum_m) { total_row = 0.0; }, + [&](int accum_m, int accum_n, int idx) { total_row += frag[idx]; }, + [&](int accum_m) { + if (LambdaIterator::reduceSameRow( + lane_id, total_row, [](accum_t a, accum_t b) { + return a + b; + })) { + // NOTE: we could atomically add `total_row` to `s_prime`, but + // it's faster (and deterministic) to avoid atomics here + addition_storage + [accum_m + kQueriesPerBlock * tile_offset.column()] = + total_row; + } + }); + } + __syncthreads(); + if (lane_id < kLinesPerWarp) { + int id = warp_id * kLinesPerWarp + lane_id; + accum_t total_row = s_prime[id]; + if (restore_mi_to_minus_inf) { + // Restore `mi`, see above when we set `restore_mi_to_minus_inf=true` + mi[id] = -cutlass::platform::numeric_limits::infinity(); + } else { + m_prime[id] = mi[id]; + } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < MM0::MmaCore::WarpCount::kN; ++i) { + total_row += addition_storage[id + kQueriesPerBlock * i]; + } + s_prime[id] = total_row; + } + } + + static CUTLASS_DEVICE int8_t lane_id() { + return threadIdx.x; + } + static CUTLASS_DEVICE int8_t warp_id() { + return threadIdx.y; + } + static CUTLASS_DEVICE int16_t thread_id() { + return threadIdx.x + threadIdx.y * blockDim.x; + } +}; + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_batched_impl(typename AK::Params p) { + if (!p.advance_to_block()) { + return; + } + AK::attention_kernel(p); +} + +template +__global__ void __launch_bounds__(AK::kNumThreads, AK::kMinBlocksPerSm) + attention_kernel_batched(typename AK::Params params); diff --git a/examples/41_fused_multi_head_attention/piped_subprocess.py b/examples/41_fused_multi_head_attention/piped_subprocess.py new file mode 100644 index 0000000000..82351f492c --- /dev/null +++ b/examples/41_fused_multi_head_attention/piped_subprocess.py @@ -0,0 +1,144 @@ +################################################################################################# +# +# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +from typing import List +import torch +import subprocess +import sys +import tempfile +import os +import numpy as np + + +TORCH_DTYPE_NAME = { + torch.float32: "f32", + torch.float16: "f16", + torch.bfloat16: "b16" +} +NAME_TORCH_DTYPE = {v: k for k, v in TORCH_DTYPE_NAME.items()} + +def _tensor_from_storage(tensor: torch.Tensor, dtype) -> torch.Tensor: + # PyTorch >= 2.0 + if hasattr(tensor, 'untyped_storage'): + return torch.tensor([], dtype=dtype).set_(tensor.untyped_storage()) + return torch.tensor([], dtype=dtype).set_(tensor.storage().untyped()) + +class PipedSubprocess: + def __init__(self, binary: str) -> None: + self.binary = binary + self.tempdir_ctx = tempfile.TemporaryDirectory() + + def __enter__(self) -> "PipedSubprocess": + self.subp = subprocess.Popen(self.binary, stdin=subprocess.PIPE, stdout=subprocess.PIPE, stderr=sys.stderr, text=True, bufsize=0) + self.tempdir = self.tempdir_ctx.__enter__() + self.file_counter = 0 + return self + + def __exit__(self, exc_type, exc_val, exc_tb) -> None: + self.tempdir_ctx.__exit__(exc_type, exc_val, exc_tb) + + def temp_filename(self, suffix: str) -> str: + self.file_counter += 1 + return os.path.join(self.tempdir, f"{self.file_counter}{suffix}") + + def write(self, *args) -> None: + for a in args: + self.subp.stdin.write(str(a) + " ") + + def writeTensor(self, tensor: torch.Tensor, name: str, stride_names: List[str]) -> None: + print(f"Py ->C++: {TORCH_DTYPE_NAME[tensor.dtype]}:{name}") + tensor_u8 = _tensor_from_storage(tensor, torch.uint8) + self.write("tensor_begin", f"{TORCH_DTYPE_NAME[tensor.dtype]}:{name}", tensor_u8.shape[0]) + filename = self.temp_filename(f"{name}.tensor") + assert tensor.storage_offset() == 0 + with open(filename, "wb+") as fd: + fd.write(bytes(tensor_u8.numpy())) + self.write("file", filename) + self.write("tensor_end") + + for stride_name, stride_value in zip(stride_names, tensor.stride()): + self.write(stride_name, stride_value) + + def readTensor(self, name, stride_name, shape) -> torch.Tensor: + tmpfile = self.temp_filename(f"{name}.tensor") + self.write("tmpfile", tmpfile) + + self.readExpect("tensor_begin") + dtype_str, name = self.read().split(":") + print(f"C++->Py : {dtype_str}:{name}") + u8len = int(self.read()) + dtype = NAME_TORCH_DTYPE[dtype_str] + + self.readExpect("file") + self.readExpect(tmpfile) + + with open(tmpfile, "rb") as fd: + data = fd.read(u8len) + # `np.array` is not strictly needed, but avoids a torch warning + tensor_u8 = torch.frombuffer(np.array(data), dtype=torch.uint8, count=u8len) + self.readExpect("tensor_end") + + tensor = _tensor_from_storage(tensor_u8, dtype) + strides = [] + for sn in stride_name: + self.readExpect(sn) + strides.append(int(self.read())) + if len(strides) != shape: + strides.append(1) + assert len(strides) == len(shape), name + return torch.as_strided(tensor, shape, strides) + + def readNamed(self, name: str): + 
self.readExpect(name) + return self.read() + + def readExpect(self, what: str) -> None: + r = self.read() + if r != what: + raise ValueError(f"Read {r} but expected {what}") + + def read(self): + read_all = [] + # Skip initial whitespace + while True: + r = self.subp.stdout.read(1) + if r not in [' ', "\n"]: + read_all.append(r) + break + # Read data + while True: + r = self.subp.stdout.read(1) + if r in [' ', "\n"]: + break + read_all.append(r) + return ''.join(read_all) + diff --git a/examples/41_fused_multi_head_attention/transform/tile_smem_loader.h b/examples/41_fused_multi_head_attention/transform/tile_smem_loader.h new file mode 100644 index 0000000000..2db928a84c --- /dev/null +++ b/examples/41_fused_multi_head_attention/transform/tile_smem_loader.h @@ -0,0 +1,90 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/numeric_types.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/predicated_tile_iterator.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +template < + typename scalar_t, // scalar type + typename ThreadblockTileShape, // size of tile to load + int Threads, // number of participating threads + int ElementsPerAccess> // thread access width in elements +class TileSmemLoader { + public: + using SmemTile = + cutlass::AlignedBuffer; + + using ThreadMap = cutlass::transform::PitchLinearStripminedThreadMap< + cutlass::layout::PitchLinearShape< + ThreadblockTileShape::kColumn, // contiguous + ThreadblockTileShape::kRow>, // strided + Threads, // Threads + ElementsPerAccess>; // ElementsPerAccess + + using GmemTileIterator = + cutlass::transform::threadblock::PredicatedTileIterator< + ThreadblockTileShape, // Shape + scalar_t, // Element + cutlass::layout::RowMajor, // Layout + 0, // AdvanceRank + ThreadMap>; // ThreadMap + + using SmemTileIterator = cutlass::transform::threadblock::RegularTileIterator< + ThreadblockTileShape, // Shape + scalar_t, // Element + cutlass::layout::RowMajor, // Layout + 0, // AdvanceRank + ThreadMap>; // ThreadMap + + using Fragment = typename GmemTileIterator::Fragment; + + /// load a tile from global memory into shared memory + CUTLASS_DEVICE + static void load( + GmemTileIterator tile_load_iter, + SmemTileIterator tile_store_iter) { + Fragment tb_frag; + tb_frag.clear(); + tile_load_iter.load(tb_frag); + tile_store_iter.store(tb_frag); + + __syncthreads(); + } +}; diff --git a/examples/42_ampere_tensorop_group_conv/CMakeLists.txt b/examples/42_ampere_tensorop_group_conv/CMakeLists.txt new file mode 100644 index 0000000000..d470548cdc --- /dev/null +++ b/examples/42_ampere_tensorop_group_conv/CMakeLists.txt @@ -0,0 +1,36 @@ + +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + +cutlass_example_add_executable( + 42_ampere_tensorop_group_conv + ampere_tensorop_group_conv.cu + ) + diff --git a/examples/42_ampere_tensorop_group_conv/ampere_tensorop_group_conv.cu b/examples/42_ampere_tensorop_group_conv/ampere_tensorop_group_conv.cu new file mode 100644 index 0000000000..120f04b649 --- /dev/null +++ b/examples/42_ampere_tensorop_group_conv/ampere_tensorop_group_conv.cu @@ -0,0 +1,706 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/** +This example shows how to run group convolution kernels using functions and data structures +provided by CUTLASS using tensor cores; which we run on a NVIDIA Ampere GPU. + +There are 2 group conv mode: + 1. cutlass::conv::GroupMode::kSingleGroup + This mode is for large K problem size: k_per_group (K/groups) equals or larger than + threadblock_tile_N. One or multiple threadblocks calculate data of one group. + 2. cutlass::conv::GroupMode::kMultipleGroup + This mode is for small K problem size: k_per_group (K/groups) is smaller than threadblock_tile_N. + One threadblock will calculate data from more than one group. 
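+
+For example, with the ThreadblockShape used in this file (64x64x64, so
+threadblock_tile_N = 64) and K = 128: groups = 2 gives k_per_group = 64 >= 64,
+so kSingleGroup applies, while groups = 8 gives k_per_group = 16 < 64, so
+kMultipleGroup is required.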
+ +Function profile_convolution_selecter() shows how to choose kernel with different group mode according +to problem size and threadblock_tile size. +*/ + +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/conv/kernel/default_conv2d_group_fprop.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/reference/device/convolution.h" +#include "cutlass/util/tensor_view_io.h" + +#include "helper.h" + +// The code section below describes datatype for input, output tensors and computation between +// elements +using ElementAccumulator = float; // Data type of accumulator +using ElementComputeEpilogue = float; // Data type of epilogue computation (alpha, beta) +using ElementInputA = cutlass::half_t; // Data type of elements in input tensor +using ElementInputB = cutlass::half_t; // Data type of elements in input tensor +using ElementOutput = float; // Data type of elements in output tensor + +using LayoutInputA = cutlass::layout::TensorNHWC; +using LayoutInputB = cutlass::layout::TensorNHWC; +using LayoutOutput = cutlass::layout::TensorNHWC; + +// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM +using MMAOp = cutlass::arch::OpClassTensorOp; + +// This code section describes CUDA SM architecture number +using SmArch = cutlass::arch::Sm80; + +// This code section describes the tile size a thread block will compute +using ThreadblockShape = cutlass::gemm::GemmShape<64, 64, 64>; // Threadblock tile shape + +// This code section describes tile size a warp will compute +using WarpShape = cutlass::gemm::GemmShape<32, 32, 64>; // Warp tile shape + +// This code section describes the size of MMA op +using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // TensorCore instruction shape + +// This code section describes how threadblocks are scheduled on GPU +using SwizzleThreadBlock = cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>; + +// Number of pipelines you want to use +constexpr int NumStages = 3; + +// This code section describes the epilogue part of the kernel, we use default value +using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // Data type of output matrix. + 128 / cutlass::sizeof_bits::value, // The number of elements per vectorized. + // memory access. This becomes the vector width of + // math instructions in the epilogue too. 
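+                                                        // (For ElementOutput = float, sizeof_bits is 32,
+                                                        // so this evaluates to 128 / 32 = 4 elements per access.)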
+ ElementAccumulator, // Data type of accumulator + ElementComputeEpilogue>; // Data type for alpha/beta in linear combination + +// Analytic kernel and operation for single group problem size +using AnalyticSingleGroupKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< + ElementInputA, LayoutInputA, + ElementInputB, LayoutInputB, + ElementOutput, LayoutOutput, + ElementAccumulator, + MMAOp, + SmArch, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOp, + SwizzleThreadBlock, + NumStages, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::GroupMode::kSingleGroup, + cutlass::conv::IteratorAlgorithm::kAnalytic +>::Kernel; +using AnalyticSingleGroupOperation = cutlass::conv::device::ImplicitGemmConvolution; + +// Analytic kernel and operation for multiple group problem size +using AnalyticMultipleGroupKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< + ElementInputA, LayoutInputA, + ElementInputB, LayoutInputB, + ElementOutput, LayoutOutput, + ElementAccumulator, + MMAOp, + SmArch, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOp, + SwizzleThreadBlock, + NumStages, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::GroupMode::kMultipleGroup, + cutlass::conv::IteratorAlgorithm::kAnalytic +>::Kernel; +using AnalyticMultipleGroupOperation = cutlass::conv::device::ImplicitGemmConvolution; + +// Optimized kernel and operation for single group problem size +using OptimizedSingleGroupKernel = typename cutlass::conv::kernel::DefaultConv2dGroupFprop< + ElementInputA, LayoutInputA, + ElementInputB, LayoutInputB, + ElementOutput, LayoutOutput, + ElementAccumulator, + MMAOp, + SmArch, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOp, + SwizzleThreadBlock, + NumStages, + cutlass::arch::OpMultiplyAdd, + cutlass::conv::GroupMode::kSingleGroup, + cutlass::conv::IteratorAlgorithm::kOptimized +>::Kernel; +using OptimizedSingleGroupOperation = cutlass::conv::device::ImplicitGemmConvolution; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + cutlass::Tensor4DCoord input_size; + cutlass::Tensor4DCoord filter_size; + cutlass::Tensor4DCoord padding; + cutlass::MatrixCoord conv_stride; + cutlass::MatrixCoord dilation; + int groups; + bool reference_check; + bool measure_performance; + int iterations; + ElementComputeEpilogue alpha; + ElementComputeEpilogue beta; + bool optimized; + std::string tag; + + Options(): + help(false), + input_size(1, 32, 32, 32), + filter_size(32, 3, 3, 32), + padding(1, 1, 1, 1), + conv_stride(1, 1), + dilation(1, 1), + groups(1), + reference_check(false), + measure_performance(false), + iterations(20), + alpha(1), + beta(0), + optimized(false) { } + + // Verify the problem size is compatible with the CUTLASS Convolution implementation. + bool valid() { + + // + // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, + // all pointers, strides, and tensor extents must be divisible by 8 elements. 
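+    // (128 bits / 16 bits per cutlass::half_t = 8 elements, hence kAlignment = 8.)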
+ // + int const kAlignment = 8; + + if ((input_size.c() % kAlignment) || + (filter_size.n() % kAlignment)) { + + // misaligned tensors + return false; + } + + // Invalid padding + if ((padding.h() != filter_size.h() / 2) || + (padding.w() != filter_size.w() / 2)) { + + return false; + } + + return true; + } + + /// Updates input and filter sizes + void update( + cutlass::Tensor4DCoord input_size, + cutlass::Tensor4DCoord filter_size) { + + this->input_size = input_size; + this->filter_size = filter_size; + + padding.n() = filter_size.h() / 2; + padding.h() = filter_size.h() / 2; + padding.w() = filter_size.w() / 2; + padding.c() = filter_size.w() / 2; + } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + if (cmd.check_cmd_line_flag("ref-check")) { + reference_check = true; + } + + if (cmd.check_cmd_line_flag("perf-check")) { + measure_performance = true; + } + + if (cmd.check_cmd_line_flag("optimized")) { + optimized = true; + } + + cmd.get_cmd_line_argument("n", input_size.n()); + cmd.get_cmd_line_argument("h", input_size.h()); + cmd.get_cmd_line_argument("w", input_size.w()); + cmd.get_cmd_line_argument("c", input_size.c()); + + cmd.get_cmd_line_argument("k", filter_size.n()); + cmd.get_cmd_line_argument("r", filter_size.h()); + cmd.get_cmd_line_argument("s", filter_size.w()); + + cmd.get_cmd_line_argument("g", groups); + filter_size.c() = input_size.c() / groups; + + cmd.get_cmd_line_argument("u", conv_stride.row()); + cmd.get_cmd_line_argument("v", conv_stride.column()); + + cmd.get_cmd_line_argument("alpha", alpha); + cmd.get_cmd_line_argument("beta", beta); + + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("tag", tag); + + if (filter_size.h() == 3 && filter_size.w() == 3) { + padding = {1, 1, 1, 1}; + } + else { + filter_size.h() = 1; + filter_size.w() = 1; + padding = {0, 0, 0, 0}; + } + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "42_ampere_tensorop_group_conv example\n\n" + << " This example uses Ampere's Tensor Core operators on F16 data types to compute\n" + << " forward grouped convolution on tensors of layout NHWC.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --n= Input tensor extent N\n" + << " --h= Input tensor extent H\n" + << " --w= Input tensor extent W\n" + << " --c= Input tensor extent C\n" + << " --k= Filter extent K\n" + << " --r= Filter extent R\n" + << " --s= Filter extent S\n\n" + << " --g= Conv groups G\n\n" + << " --u= Conv stride_h\n\n" + << " --v= Conv stride_w\n\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --ref-check If set (true), reference check is computed\n" + << " --perf-check If set (true), performance is measured.\n" + << " --optimized If set (true), use optimized kernel, otherwise use analytic kernel.\n" + << " --iterations= Number of profiling iterations to perform.\n" + << " --tag= String to replicate across the first column in the results table\n"; + + out << "\n\nExamples:\n\n" + << "$ ./examples/42_ampere_tensorop_group_conv/42_ampere_tensorop_group_conv --n=4 --h=16 --w=16 --c=256 --k=128 --r=3 --s=3 --g=8 --ref-check\n\n" + << "$ ./examples/42_ampere_tensorop_group_conv/42_ampere_tensorop_group_conv --n=4 --h=16 --w=16 --c=256 --k=128 --r=3 --s=3 --g=2 --ref-check\n\n" + << "$ ./examples/42_ampere_tensorop_group_conv/42_ampere_tensorop_group_conv --n=4 --h=16 --w=16 --c=256 --k=128 --r=3 --s=3 --g=2 --ref-check --optimized\n\n"; + + return out; + } + + /// Computes the output tensor size (NPQK) + cutlass::Tensor4DCoord output_size() const { + return cutlass::Tensor4DCoord( + input_size.n(), + (input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1, + (input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1, + filter_size.n()); + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + + // Number of multiply-adds = NPQK * CRS + int64_t fmas = output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c()); + + // Two flops per multiply-add + return 2.0 * double(fmas) / double(1.0e9) / runtime_s; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct Result { + double runtime_ms; + double gflops; + cutlass::Status status; + cutlass::Status reference_check; + cudaError_t error; + + Result(): + runtime_ms(0), + gflops(0), + status(cutlass::Status::kSuccess), + reference_check(cutlass::Status::kInvalid), + error(cudaSuccess) { } + + static std::ostream & print_header(std::ostream &out, Options const &options) { + + if (!options.tag.empty()) { + out << "Name,"; + } + + out << "Layer,N,H,W,C,K,R,S,G,Runtime,GFLOPs"; + + return out; + } + + std::ostream & print(std::ostream &out, int idx, Options const &options) { + + if (!options.tag.empty()) { + out << options.tag << ","; + } + + out + << "conv_" << idx << "," + << options.input_size.n() << "," + << options.input_size.h() << "," + << options.input_size.w() << "," + << options.input_size.c() << "," + << options.filter_size.n() << "," + << options.filter_size.h() << "," + << options.filter_size.w() << "," + << options.groups << "," + << runtime_ms << "," + << gflops; + + return out; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// 
Runs one benchmark
+template <typename Conv2dOperation>
+Result profile_convolution(Options const &options) {
+
+ Result result;
+
+ //
+ // Allocate host-device tensors using the CUTLASS Utilities.
+ //
+
+ cutlass::HostTensor<ElementInputA, LayoutInputA> tensor_a(options.input_size);
+ cutlass::HostTensor<ElementInputB, LayoutInputB> tensor_b(options.filter_size);
+ cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_c(options.output_size());
+ cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_d(options.output_size());
+ cutlass::HostTensor<ElementOutput, LayoutOutput> tensor_ref_d(options.output_size());
+
+ //
+ // Initialize tensors
+ //
+
+ // Fill tensor A on host with uniform-distribution random data
+ cutlass::reference::host::TensorFillRandomUniform(
+ tensor_a.host_view(),
+ 1,
+ ElementInputA(7),
+ ElementInputA(-8),
+ 0);
+
+ // Fill tensor B on host with uniform-distribution random data
+ cutlass::reference::host::TensorFillRandomUniform(
+ tensor_b.host_view(),
+ 1,
+ ElementInputB(7),
+ ElementInputB(-8),
+ 0);
+
+ // Fill tensor C on host with uniform-distribution random data
+ cutlass::reference::host::TensorFillRandomUniform(
+ tensor_c.host_view(),
+ 1,
+ ElementOutput(7),
+ ElementOutput(-8),
+ 0);
+
+ // Fill tensor D on host with zeros
+ cutlass::reference::host::TensorFill(
+ tensor_d.host_view());
+
+ // Fill tensor D for reference on host with zeros
+ cutlass::reference::host::TensorFill(
+ tensor_ref_d.host_view());
+
+ // Copy data from host to GPU
+ tensor_a.sync_device();
+ tensor_b.sync_device();
+ tensor_c.sync_device();
+ tensor_d.sync_device();
+ tensor_ref_d.sync_device();
+
+ //
+ // Define arguments for CUTLASS Convolution
+ //
+
+ cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation;
+
+ // Split K dimension into 1 partition
+ int split_k_slices = 1;
+
+ // Construct Conv2dProblemSize with user defined output size
+ cutlass::conv::Conv2dProblemSize problem_size(
+ options.input_size,
+ options.filter_size,
+ options.padding,
+ options.conv_stride,
+ options.dilation,
+ options.output_size(),
+ mode,
+ split_k_slices,
+ options.groups
+ );
+
+ // Construct Conv2dOperation::Argument structure with conv2d
+ // problem size, data pointers, and epilogue values
+ typename Conv2dOperation::Arguments arguments{
+ problem_size,
+ tensor_a.device_ref(),
+ tensor_b.device_ref(),
+ tensor_c.device_ref(),
+ tensor_d.device_ref(),
+ {options.alpha, options.beta},
+ };
+
+ //
+ // Initialize CUTLASS Convolution
+ //
+
+ Conv2dOperation implicit_gemm_op;
+
+ size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments);
+
+ // Allocate workspace memory
+ cutlass::device_memory::allocation<uint8_t> workspace(workspace_size);
+
+ result.status = implicit_gemm_op.can_implement(arguments);
+ CUTLASS_CHECK(result.status);
+
+ result.status = implicit_gemm_op.initialize(arguments, workspace.get());
+ CUTLASS_CHECK(result.status);
+
+ //
+ // Launch initialized CUTLASS kernel
+ //
+ result.status = implicit_gemm_op();
+
+ CUTLASS_CHECK(result.status);
+
+ //
+ // Optional reference check
+ //
+
+ if (options.reference_check) {
+ std::cout << "Verification on device...\n";
+
+ // Compute with reference implementation
+ cutlass::reference::device::Conv2dFprop<
+ ElementInputA,
+ LayoutInputA,
+ ElementInputB,
+ LayoutInputB,
+ ElementOutput,
+ LayoutOutput,
+ ElementComputeEpilogue,
+ ElementAccumulator,
+ cutlass::NumericConverter<ElementOutput, ElementComputeEpilogue>
+ >(
+ problem_size,
+ tensor_a.device_ref(),
+ tensor_b.device_ref(),
+ tensor_c.device_ref(),
+ tensor_ref_d.device_ref(),
+ options.alpha,
+ options.beta
+ );
+
+ tensor_ref_d.sync_host();
+
+ // Check if output from CUTLASS kernel and reference kernel are equal or not
+
tensor_d.sync_host(); + + bool passed = cutlass::reference::host::TensorEquals( + tensor_d.host_view(), + tensor_ref_d.host_view()); + + if (!passed) { + result.reference_check = cutlass::Status::kErrorInternal; + std::cout << "ERROR - results miscompared.\n"; + } else { + result.reference_check = cutlass::Status::kSuccess; + std::cout << "Passed.\n"; + } + } else { + result.reference_check = cutlass::Status::kInvalid; + } + + // + // Performance measurement + // + + if (options.measure_performance) { + + cudaEvent_t events[2]; + + for (auto & event : events) { + result.error = cudaEventCreate(&event); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + } + + // Record an event at the start of a series of convolution operations. + result.error = cudaEventRecord(events[0]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Launch a sequence of implicit GEMM operations on the device + for (int iteration = 0; iteration < options.iterations; ++iteration) { + result.status = implicit_gemm_op(); + CUTLASS_CHECK(result.status); + } + + // Record an event when the convolutions have been launched. + result.error = cudaEventRecord(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Wait for work on the device to complete. + result.error = cudaEventSynchronize(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Print average runtime and GFLOPs. 
+ result.runtime_ms = double(runtime_ms) / double(options.iterations);
+ result.gflops = options.gflops(result.runtime_ms / 1000.0);
+
+ // Cleanup
+ for (auto event : events) {
+ (void)cudaEventDestroy(event);
+ }
+ }
+
+ return result;
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+Result profile_convolution_selecter(Options const &options) {
+ int k_per_group = options.filter_size.n() / options.groups;
+
+ // In group conv, if k_per_group < threadblock_N, one Threadblock will calculate multiple groups
+ if (k_per_group < ThreadblockShape::kN) { // MultipleGroup mode
+ if (options.optimized) {
+ std::cerr << "Invalid problem: the optimized group conv kernel doesn't support MultipleGroup mode (one CTA calculates multiple groups)" << std::endl;
+ exit(-1);
+ } else {
+ std::cout << "Select AnalyticMultipleGroupOperation\n";
+ return profile_convolution<AnalyticMultipleGroupOperation>(options);
+ }
+ } else { // SingleGroup mode
+ if (options.optimized) {
+ std::cout << "Select OptimizedSingleGroupOperation\n";
+ return profile_convolution<OptimizedSingleGroupOperation>(options);
+ } else {
+ std::cout << "Select AnalyticSingleGroupOperation\n";
+ return profile_convolution<AnalyticSingleGroupOperation>(options);
+ }
+ }
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+int main(int argc, char const **args) {
+
+ bool notSupported = false;
+
+ // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0.
+ //
+ // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples.
+ if (!(__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0))) {
+ std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl;
+ notSupported = true;
+ }
+
+ cudaDeviceProp props;
+ CUDA_CHECK(cudaGetDeviceProperties(&props, 0));
+
+ if (!(props.major > 8 || (props.major == 8 && props.minor >= 0))) {
+ std::cerr << "Ampere Tensor Ops must be run on a machine with compute capability at least 80."
+ << std::endl;
+ notSupported = true;
+ }
+
+ if (notSupported) {
+ return 0;
+ }
+
+ Options options;
+
+ options.parse(argc, args);
+
+ if (options.help) {
+ options.print_usage(std::cout) << std::endl;
+ return 0;
+ }
+
+ // Execute one problem size
+ if (!options.valid()) {
+ std::cerr << "Invalid problem." << std::endl;
+ return -1;
+ }
+
+ Result result = profile_convolution_selecter(options);
+
+ Result::print_header(std::cout, options) << std::endl;
+ result.print(std::cout, 1, options) << std::endl;
+
+ return 0;
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/examples/43_ell_block_sparse_gemm/CMakeLists.txt b/examples/43_ell_block_sparse_gemm/CMakeLists.txt
new file mode 100644
index 0000000000..0676c7bd31
--- /dev/null
+++ b/examples/43_ell_block_sparse_gemm/CMakeLists.txt
@@ -0,0 +1,34 @@
+# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +cutlass_example_add_executable( + 43_ell_block_sparse_gemm + ell_block_sparse_gemm.cu + ) + diff --git a/examples/43_ell_block_sparse_gemm/ell_block_sparse_gemm.cu b/examples/43_ell_block_sparse_gemm/ell_block_sparse_gemm.cu new file mode 100644 index 0000000000..52d2d0cbfa --- /dev/null +++ b/examples/43_ell_block_sparse_gemm/ell_block_sparse_gemm.cu @@ -0,0 +1,740 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Block-Ell sparse gemm example. + + This example performs a Sparse-matrix dense-matrix multiplication (SpMM) operation. + Matrix A is stored in the Blocked-Ellpack (Blocked-ELL) storage format. 
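+ As a small illustrative example (not taken from this source): an 8x8 sparse matrix with
+ a_ell_blocksize = 2 and at most two non-zero 2x2 blocks per block-row would be stored as an
+ 8x4 ellValue matrix (each block-row keeps its stored blocks side by side) plus a 4x2 ellColInd
+ matrix holding the block-column index of every stored block, with -1 marking an empty block;
+ these terms are defined in the parameter list below.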
+ Details about the Blocked-Ellpack (Blocked-ELL) storage format can be found here: + https://docs.nvidia.com/cuda/cusparse/index.html#cusparse-generic-spmat-create-blockedell + Whereas matrix B is a dense matrix. + + Blocked-Ellpack or Blocked-ELL storage format comprises of two matrices. + First is a packed matrix (ellValue matrix) that stores non-zero values in consecutive blocks, + represented by tensor_a in this example. Second is a matrix of indices (ellColInd matrix), + represented by tensor_ell_idx in this example, that represent the column indices of the + corresponding non-zero blocks. All rows in the matrices must have the same number of blocks. + ellColInd can contain -1 values for indicating empty blocks. These matrices store elements in + row-major order. + + Description of parameters and tensors used to represent the Blocked-Ellpack (ELL) format + for this example: + a_rows - Rows in the sparse matrix. + a_cols - Colums in the sparse matrix. + a_ell_blocksize - Size of the ELL-Blocks. + a_ell_num_columns - Number of columns in the Blocked-Ellpack format (ellValue columns) + tensor_a - ellValue matrix, whose size is (a_rows * a_ell_num_columns) + tensor_ell_idx - Blocked-ELL Column indices (ellColInd), whose size is + (a_rows / a_ell_blocksize) * (a_ell_num_columns / a_ell_blocksize) + tensor_b - Input dense matrix whose size is (a_cols * n) + tensor_c/tensor_d - Output dense matrix whose size is (a_rows * n) + {a_rows, n, a_cols} - Problem size + +*/ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_grouped.h" +#include "cutlass/gemm/kernel/default_gemm_grouped.h" +#include "cutlass/gemm/device/ell_gemm.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/host_uncompress.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Result structure +struct Result { + + double runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + // + // Methods + // + + Result( + double runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess + ): + runtime_ms(runtime_ms), gflops(gflops), status(status), error(error), passed(true) { } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + bool reference_check; + int iterations; + int cuda_streams; + int a_rows, n, a_cols; + int a_ell_num_columns; + int a_ell_blocksize; + int a_base; + float alpha; + float beta; + + // + // Methods + // + + Options(): + help(false), + reference_check(true), + iterations(20), + cuda_streams(0), + a_rows(1024), + n(1024), + a_cols(1024), + a_ell_num_columns(512), + a_ell_blocksize(16), + a_base(0), + alpha(1), + beta() + { } + + // Parses the command line + void parse(int argc, char const **args) { + 
cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + cmd.get_cmd_line_argument("alpha", alpha, 1.0f); + cmd.get_cmd_line_argument("beta", beta, 0.0f); + cmd.get_cmd_line_argument("iterations", iterations, 20); + cmd.get_cmd_line_argument("streams", cuda_streams, 0); + cmd.get_cmd_line_argument("reference-check", reference_check, true); + + cmd.get_cmd_line_argument("a_rows", a_rows, 1024); + cmd.get_cmd_line_argument("n", n, 1024); + cmd.get_cmd_line_argument("a_cols", a_cols, 1024); + + cmd.get_cmd_line_argument("a_ell_num_columns", a_ell_num_columns, 512); + cmd.get_cmd_line_argument("a_ell_blocksize", a_ell_blocksize, 16); + cmd.get_cmd_line_argument("a_base", a_base, 0); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "43_ell_block_sparse_gemm\n\n" + << " This example profiles the performance of a ELL block sparse GEMM kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --a_rows= Sets the number of the rows of the sparse matrix.\n" + << " --n= Sets the N dimension.\n" + << " --a_cols= Sets the number of columns of the sparse matrix.\n" + << " --a_ell_num_columns= Sets the actual number of columns of the Blocked-Ellpack format.\n" + << " --a_ell_blocksize= Sets the size of the ELL-Block.\n" + << " --a_base= Sets the base index.\n" + << " --alpha= Epilogue scalar alpha (real part)\n" + << " --beta= Epilogue scalar beta (real part)\n\n" + << " --iterations= Number of profiling iterations to perform.\n" + << " --reference-check= If true, performs reference check.\n"; + + out << "\n\nExamples:\n\n" + + << "# Runs a 1024x1024x1024 ELL block sparse GEMM with 16x16 block size and actual 512 non-zero columns in A operand\n" + << "$ ./examples/43_ell_block_sparse_gemm/43_ell_block_sparse_gemm --a_rows=1024 --n=1024 --a_cols=1024 --a_ell_num_columns=512 --a_ell_blocksize=16\n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + + // Number of real-valued multiply-adds + int64_t fmas = (int64_t)a_rows * (int64_t)a_cols * (int64_t)n; + + // Two flops per multiply-add + return 2.0 * double(fmas) / double(1.0e9) / runtime_s; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class Testbed { +public: + + // + // Type definitions + // + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + + using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + + using MatrixCoord = typename LayoutC::TensorCoord; + +private: + + // + // Data members + // + + Options options; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_ELL; + uint32_t seed; + + cutlass::HostTensor tensor_a; + cutlass::HostTensor tensor_b; + cutlass::HostTensor tensor_c; + cutlass::HostTensor tensor_d; + + cutlass::HostTensor tensor_a_uncompressed; + cutlass::HostTensor reference_d; + + cutlass::HostTensor tensor_ell_idx; + +public: + + // + // Methods + // + + Testbed( + Options const 
&options_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_ELL_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3080 + ): + options(options_), init_A(init_A_), init_B(init_B_), init_C(init_C_), init_ELL(init_ELL_), seed(seed_) { } + +private: + + /// Helper to initialize a tensor view + template + void initialize_tensor_( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint32_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian( + view, seed, Element(), Element(0.5f)); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + // Fill with increasing elements + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity(), Element(1), Element()); + } else { + + // Fill with all 1s + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity(), Element(), Element(1)); + } + } + + /// Initializes data structures + void initialize_() { + tensor_a.resize(cutlass::make_Coord(options.a_rows, options.a_ell_num_columns)); + tensor_b.resize(cutlass::make_Coord(options.a_cols, options.n)); + tensor_c.resize(cutlass::make_Coord(options.a_rows, options.n)); + tensor_d.resize(cutlass::make_Coord(options.a_rows, options.n)); + + tensor_a_uncompressed.resize(cutlass::make_Coord(options.a_rows, options.a_cols)); + reference_d.resize(cutlass::make_Coord(options.a_rows, options.n)); + + tensor_ell_idx.resize(cutlass::make_Coord(options.a_rows / options.a_ell_blocksize, + options.a_ell_num_columns / options.a_ell_blocksize)); + + // + // Initialize the problems of the workspace + // + + initialize_tensor_(tensor_a.host_view(), init_A, seed * 2021); + initialize_tensor_(tensor_b.host_view(), init_B, seed * 2022); + initialize_tensor_(tensor_c.host_view(), init_C, seed * 2023); + + if (init_ELL == cutlass::Distribution::Uniform) { + cutlass::reference::host::TensorFillRandomEllIdx( + tensor_ell_idx.host_view(), seed, + options.a_rows / options.a_ell_blocksize, + options.a_ell_num_columns / options.a_ell_blocksize, + options.a_cols / options.a_ell_blocksize); + + } else { + for(int i = 0; i < options.a_rows / options.a_ell_blocksize; ++i) { + for(int j = 0; j < options.a_ell_num_columns / options.a_ell_blocksize; ++j) { + tensor_ell_idx.at({i, j}) = j+3; + } + } + } + + tensor_a.sync_device(); + tensor_b.sync_device(); + tensor_c.sync_device(); + tensor_d.sync_device(); + tensor_ell_idx.sync_device(); + } + + /// Verifies the result is a GEMM + bool verify_() { + + bool passed = true; + + tensor_d.sync_host(); + + cutlass::uncompress_ell_block_sparse( + tensor_a_uncompressed.host_ref(), + tensor_a.host_ref(), + tensor_ell_idx.host_ref(), + options.a_rows, + 
options.a_cols, + options.a_ell_num_columns, + options.a_ell_blocksize + ); + + cutlass::reference::host::Gemm< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, + ElementCompute, + ElementAccumulator, typename Gemm::Operator> + reference_gemm; + + reference_gemm( + {options.a_rows, options.n, options.a_cols}, + options.alpha, + tensor_a_uncompressed.host_ref(), + tensor_b.host_ref(), + options.beta, + reference_d.host_ref(), + ElementAccumulator(0) + ); + + // Reference check + passed = cutlass::reference::host::TensorEquals(tensor_d.host_view(), reference_d.host_view()); + + if (!passed) { + std::cerr << "\n***\nError - problem failed the QA check\n***\n" << std::endl; + + std::stringstream fname; + + fname << "error_43_ell_block_sparse_gemm" + << "mnk_" + << options.a_rows << "x" + << options.n << "x" + << options.a_cols << "_" + << options.a_ell_num_columns << "_" + << options.a_ell_blocksize << ".txt"; + + std::cout << fname.str() << std::endl; + + std::ofstream results(fname.str()); + + results + << "alpha: " << ElementCompute(options.alpha) << "\n" + << "beta: " << ElementCompute(options.beta) << "\n" + << "block size: " << options.a_ell_blocksize << "\n" + << "\nA:\n" << tensor_a.host_view() << "\n" + << "\nA Ell Index:\n" << tensor_ell_idx.host_view() << "\n" + << "\nB:\n" << tensor_b.host_view() << "\n" + << "\nC:\n" << tensor_c.host_view() << "\n" + << "\nD reference:\n" << reference_d.host_view() << "\n" + << "\nD computed:\n" << tensor_d.host_view() << "\n"; + + + return passed; + } + + return passed; + } + +public: + + /// Returns the number of threadblocks to launch if the kernel can run on the target + /// device. Otherwise, returns zero. + bool sufficient() const { + // + // Determine SMEM requirements and waive if not satisfied + // + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDevice() API call failed."); + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + throw std::runtime_error("cudaGetDeviceProperties() failed"); + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + return false; + } + + return true; + } + + /// Executes a BlockedEll SpMM kernel and measures runtime. + Result profile() { + + Result result; + + // Early exit + if (!sufficient()) { + std::cout << "Active CUDA device lacks hardware resources to run CUTLASS BlockedEll SpMM kernel." << std::endl; + return result; + } + + result.passed = false; + + // Initialize the problem + initialize_(); + + // Configure the GEMM arguments + typename EpilogueOutputOp::Params epilogue_op(options.alpha, options.beta); + + // Configure GEMM arguments + typename Gemm::Arguments args( + {options.a_rows, options.n, options.a_cols}, + tensor_a.device_ref(), + tensor_b.device_ref(), + tensor_c.device_ref(), + tensor_d.device_ref(), + tensor_ell_idx.device_data(), + options.a_ell_num_columns, + options.a_ell_blocksize, + options.a_base, + epilogue_op + ); + + // Initialize the GEMM object + Gemm gemm{}; + + result.status = gemm.initialize(args); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize CUTLASS BlockedEll SpMM kernel." 
<< std::endl; + return result; + } + + // Run the BlockedEll SpMM object + result.status = gemm.run(); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run CUTLASS BlockedEll SpMM kernel." << std::endl; + return result; + } + + // Wait for completion + result.error = cudaDeviceSynchronize(); + + if (result.error != cudaSuccess) { + std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); + return result; + } + + // + // Verify correctness + // + result.passed = true; + + if (options.reference_check) { + result.passed = verify_(); + } + + // + // Warm-up run + // + result.status = gemm.run(); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run CUTLASS BlockedEll SpMM kernel." << std::endl; + return result; + } + + // + // Construct events + // + + cudaEvent_t events[2]; + + for (auto & event : events) { + result.error = cudaEventCreate(&event); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; + return -1; + } + } + + // Record an event at the start of a series of GEMM operations + result.error = cudaEventRecord(events[0]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // + // Run profiling loop + // + + for (int iter = 0; iter < options.iterations; ++iter) { + gemm(); + } + + // + // Stop profiling loop + // + + // Record an event when the GEMM operations have been launched. + result.error = cudaEventRecord(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Wait for work on the device to complete. + result.error = cudaEventSynchronize(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Compute average runtime and GFLOPs. + result.runtime_ms = double(runtime_ms) / double(options.iterations); + result.gflops = options.gflops(result.runtime_ms / 1000.0); + + // + // Cleanup + // + + for (auto event : events) { + (void)cudaEventDestroy(event); + } + + std::cout << std::endl; + std::cout << "ELL Block Sparse GEMM (CUTLASS):\n" + << "====================================================" << std::endl; + + std::cout << std::endl; + std::cout << " " << "Runtime: " << result.runtime_ms << " ms" << std::endl; + std::cout << " " << " GFLOPs: " << result.gflops << std::endl; + + return result; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // + // This example uses mma.sync to directly access Tensor Cores to achieve peak performance. + // + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (__CUDACC_VER_MAJOR__ < 11 || props.major < 8) { + + // + // This example requires an NVIDIA Ampere-architecture GPU. 
+ // + + std::cout + << "CUTLASS's BlockedEll SpMM example requires a GPU of NVIDIA's Ampere Architecture or " + << "later (compute capability 80 or greater).\n"; + + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Define the BlockedEll type + // + + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementOutput = cutlass::half_t; + using ElementAccumulator = float; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + + constexpr int32_t kAlignmentA = 128 / cutlass::sizeof_bits::value; + constexpr int32_t kAlignmentB = 128 / cutlass::sizeof_bits::value; + + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + constexpr int32_t kStages = 4; + using Gemm = typename cutlass::gemm::device::EllGemm< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementOutput, + LayoutC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<8>, + kStages, kAlignmentA, kAlignmentB>; + + // + // Profile it + // + + Testbed testbed(options); + + if (!testbed.sufficient()) { + std::cout << "The active CUDA device lacks sufficient hardware resources to execute this kernel.\n"; + return 0; + } + + Result result = testbed.profile(); + if (!result.passed) { + std::cout << "Profiling CUTLASS ELL block sparse GEMM has failed.\n"; + std::cout << "\nFailed\n"; + return -1; + } + + std::cout << "\nPassed\n"; + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/44_multi_gemm_ir_and_codegen/README.md b/examples/44_multi_gemm_ir_and_codegen/README.md new file mode 100644 index 0000000000..fd1839c5ad --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/README.md @@ -0,0 +1,63 @@ +This example provides utilities for generating back-to-back (B2B) GEMMs using CUTLASS. + +## Quick start +A configuration file containing the GEMMs to be fused together is located in [config.json](config.json). Edit +this to change the configuration that you would like to run. +```shell +cd ir_gen + +# Set up basic variables +out_dir=directory_to_emit_files +cutlass_dir=$(pwd)/../../.. +config_file=$(pwd)/../config.json + +# Generate code for GEMMs described in `config_file` +./generate.sh $config_file $out_dir $cutlass_dir + +# Build the generated code +cd $out_dir +mkdir build && cd build +cmake .. -DGPU_ARCHS="75;80" +make -j + +# Run the generated code with M=1024 K0=32 and Batch=1 +./sample 1024 32 1 +``` + +## Current restrictions +This experimental example has the following restrictions: +1. N tile should not exceed 256, or register spilling will occur. +2. Only FP16 is supported currently +3. Matrix A must be row major, matrix B must be column major, matrices C and D must be row major. + +## Copyright + +Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+SPDX-License-Identifier: BSD-3-Clause + +``` + Redistribution and use in source and binary forms, with or without + modification, are permitted provided that the following conditions are met: + + 1. Redistributions of source code must retain the above copyright notice, this + list of conditions and the following disclaimer. + + 2. Redistributions in binary form must reproduce the above copyright notice, + this list of conditions and the following disclaimer in the documentation + and/or other materials provided with the distribution. + + 3. Neither the name of the copyright holder nor the names of its + contributors may be used to endorse or promote products derived from + this software without specific prior written permission. + + THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +``` diff --git a/examples/44_multi_gemm_ir_and_codegen/config.json b/examples/44_multi_gemm_ir_and_codegen/config.json new file mode 100644 index 0000000000..bb8757c092 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/config.json @@ -0,0 +1,32 @@ +{ + "0": { + "A_tp": "fp16", "B_tp": "fp16", "C_tp": "fp16", "Acc_tp": "fp16", + "A_format": "Row", "B_format": "Col", "C_format": "Row", + "mnk": [15000, 256, 32], + "epilogue": { + "tp": "LeakyRelu", + "bias": {"addbias": false, "bias_tp": "mat"}, + "args": [["float", "leaky_alpha", 1.3]] + } + }, + "1": { + "A_tp": "fp16", "B_tp": "fp16", "C_tp": "fp16", "Acc_tp": "fp16", + "A_format": "Row", "B_format": "Col", "C_format": "Row", + "mnk": [15000, 128, 256], + "epilogue": { + "tp": "LeakyRelu", + "bias": {"addbias": false, "bias_tp": "mat"}, + "args": [["float", "leaky_alpha", 1.3]] + } + }, + "2": { + "A_tp": "fp16", "B_tp": "fp16", "C_tp": "fp16", "Acc_tp": "fp16", + "A_format": "Row", "B_format": "Col", "C_format": "Row", + "mnk": [15000, 64, 128], + "epilogue": { + "tp": "LeakyRelu", + "bias": {"addbias": false, "bias_tp": "mat"}, + "args": [["float", "leaky_alpha", 1.3]] + } + } +} diff --git a/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_bias_act_epilogue_tensor_op.h b/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_bias_act_epilogue_tensor_op.h new file mode 100644 index 0000000000..2535e28e22 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_bias_act_epilogue_tensor_op.h @@ -0,0 +1,154 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_clamp.h" +#include "cutlass/epilogue/thread/conversion_op.h" +#include "cutlass/epilogue/thread/reduction_op.h" + +#include "cutlass/transform/threadblock/regular_tile_iterator_pitch_linear.h" + +#include "cutlass/epilogue/warp/fragment_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/fragment_iterator_complex_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op.h" +#include "cutlass/epilogue/warp/tile_iterator_tensor_op_mixed.h" +#include "cutlass/epilogue/threadblock/default_thread_map_tensor_op.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/epilogue/threadblock/shared_load_iterator.h" +#include "cutlass/epilogue/threadblock/shared_load_iterator_mixed.h" + +// #include "cutlass/epilogue/threadblock/epilogue.h" +#include "cutlass/epilogue/threadblock/interleaved_epilogue.h" + +#include "fused_bias_act_epilogue.h" +#include "../warp/fused_bias_act_fragment_iterator_tensor_op.h" +#include "output_tile_thread_map_for_fused_bias.h" +#include "default_thread_map_tensor_op_for_fused_bias.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + + +//////////////////////////////////////////////////////////////////////////////// + +/// Defines sensible defaults for epilogues for TensorOps. 
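+/// DefaultFusedBiasActEpilogueTensorOp (below) assembles the fused bias/activation epilogue from
+/// the pieces defined in this example: the DefaultThreadMapTensorOpForFusedBias thread map, the
+/// FusedBiasActFragmentIteratorTensorOp accumulator fragment iterator (or the complex-valued
+/// fragment iterator when the accumulator type is complex), and the FusedBiasActEpilogue itself.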
+template < + typename Shape_, + typename WarpMmaTensorOp_, + int PartitionsK, + typename OutputOp_, + int ElementsPerAccess +> +struct DefaultFusedBiasActEpilogueTensorOp { + + using Shape = Shape_; + using WarpMmaTensorOp = WarpMmaTensorOp_; + static int const kPartitionsK = PartitionsK; + using OutputOp = OutputOp_; + static int const kElementsPerAccess = ElementsPerAccess; + using ElementOutput = typename OutputOp::ElementOutput; + using LayoutC = typename WarpMmaTensorOp::LayoutC; + using ElementAccumulator = typename WarpMmaTensorOp::ElementC; + + // + // Thread map + // + + using OutputTileThreadMap = typename cutlass::epilogue::threadblock::DefaultThreadMapTensorOpForFusedBias< + Shape, + typename WarpMmaTensorOp::Shape, + kPartitionsK, + ElementOutput, + kElementsPerAccess + >::Type; + + using OutputTileIterator = cutlass::epilogue::threadblock::PredicatedTileIterator< + OutputTileThreadMap, + ElementOutput + >; + + using AccumulatorFragmentIterator = typename std::conditional::value, + cutlass::epilogue::warp::FragmentIteratorComplexTensorOp< + typename WarpMmaTensorOp::Shape, + typename WarpMmaTensorOp::Policy::Operator::Shape, + typename WarpMmaTensorOp::Policy::Operator::ElementC, + typename WarpMmaTensorOp::Policy::Operator::FragmentC, + LayoutC>, + cutlass::epilogue::warp::FusedBiasActFragmentIteratorTensorOp< + typename WarpMmaTensorOp::Shape, + typename WarpMmaTensorOp::Policy::Operator::Shape, + typename WarpMmaTensorOp::Policy::Operator::ElementC, + typename WarpMmaTensorOp::Policy::Operator::FragmentC, + LayoutC> >::type; + + // + // Define the epilogue + // + using Epilogue = cutlass::epilogue::threadblock::FusedBiasActEpilogue< + Shape, + WarpMmaTensorOp, + kPartitionsK, + OutputTileIterator, + AccumulatorFragmentIterator, + OutputOp + >; +}; + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_thread_map_tensor_op_for_fused_bias.h b/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_thread_map_tensor_op_for_fused_bias.h new file mode 100644 index 0000000000..22f8e282a3 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/default_thread_map_tensor_op_for_fused_bias.h @@ -0,0 +1,113 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + +*/ + +#pragma once + +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/layout/pitch_linear.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Defines the optimal thread map for TensorOp accumulator layouts +template < + typename ThreadblockShape_, + typename WarpShape_, + int PartitionsK, + typename Element_, + int ElementsPerAccess +> +struct DefaultThreadMapTensorOpForFusedBias { + + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + static int const kPartitionsK = PartitionsK; + using Element = Element_; + static int const kElementsPerAccess = ElementsPerAccess; + + // + // Definitions + // + + struct Detail { + + /// Tensor Operations fundamentally perform operations on 8 rows + static int const kTensorOpRows = 8; + static int const kWarpSize = 32; + + static_assert( + !(ThreadblockShape::kM % WarpShape::kM) && + !(ThreadblockShape::kM % WarpShape::kM), "Divisibility"); + + /// Number of warps + using WarpCount = gemm::GemmShape< + ThreadblockShape::kM / WarpShape::kM, + ThreadblockShape::kN / WarpShape::kN, + kPartitionsK + >; + + /// Number of participating threads + static int const kThreads = WarpCount::kCount * kWarpSize; + }; + + // + // ThreadMap + // + + /// ThreadMap to be used by epilogue::PredicatedTileIterator satisfying concept OutputTileThreadMap + using Type = OutputTileOptimalThreadMapBiasAct < + OutputTileShape, + OutputTileShape<1, WarpShape::kM / Detail::kTensorOpRows, 1, 1, WarpShape::kM / Detail::kTensorOpRows>, + Detail::kThreads, + kElementsPerAccess, + sizeof_bits::value + >; +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h b/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h new file mode 100644 index 0000000000..1acb4a2de6 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h @@ -0,0 +1,215 @@ 
+/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. 
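+
+ Note: unlike the standard threadblock epilogue, the FusedBiasActEpilogue defined below does not
+ go through shared memory or write to global memory. It applies the output operator (bias and
+ activation) to each accumulator fragment, optionally reading a source tile, and stores the
+ result into a second accumulator tile (fused_bias_act_accumlators) so that it can feed the
+ next GEMM of the fused sequence.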
+ +*/ + +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/vector.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/functional.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator without splitk +template < + typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) + typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) + int PartitionsK, ///< Number of partitions of the K dimension + typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors + typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators + typename OutputOp_ ///< Output operator +> +class FusedBiasActEpilogue { + +public: + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + using OutputTileIterator = OutputTileIterator_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using OutputOp = OutputOp_; + + /// Output layout is always row-major + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// The complete warp-level accumulator tile + using AccumulatorTile = typename AccumulatorFragmentIterator::AccumulatorTile; + + /// Output element + using ElementOutput = typename OutputTileIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + +public: + + + static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero."); + + static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), + "Divisibility"); + +public: + + /// Constructor + CUTLASS_DEVICE + FusedBiasActEpilogue( + ){ } + + /// Streams the result to global memory + CUTLASS_DEVICE + void operator()( + OutputOp const &output_op, ///< Output operator + AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile + AccumulatorTile & fused_bias_act_accumlators, + OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + + bool need_bias = output_op.is_source_needed(); + + if (need_bias) + compute_source_needed_(output_op, accumulators, fused_bias_act_accumlators, source_iterator); + else + compute_source_no_needed_(output_op, accumulators, fused_bias_act_accumlators); + + + } + + CUTLASS_DEVICE + void operator()( + OutputOp const &output_op, ///< Output operator + AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile + AccumulatorTile & fused_bias_act_accumlators) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + + compute_source_no_needed_(output_op, accumulators, fused_bias_act_accumlators); + } + + CUTLASS_DEVICE + void compute_source_needed_( + OutputOp const &output_op, ///< Output operator + AccumulatorTile &accumulators, ///< 
Complete warp-level accumulator tile + AccumulatorTile & fused_bias_act_accumlators, + OutputTileIterator source_iterator) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + + typename OutputTileIterator::Fragment source_fragment; + + + source_fragment.clear(); + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + AccumulatorFragmentIterator fused_bias_act_fragment_iterator(fused_bias_act_accumlators); + + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + + source_iterator.load(source_fragment); + ++source_iterator; + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + + typename AccumulatorFragmentIterator::Fragment fused_bias_act_fragment; + fused_bias_act_fragment = output_op(accum_fragment, source_fragment); + + fused_bias_act_fragment_iterator.store(fused_bias_act_fragment); + ++fused_bias_act_fragment_iterator; + } + } + + CUTLASS_DEVICE + void compute_source_no_needed_( + OutputOp const &output_op, ///< Output operator + AccumulatorTile &accumulators, ///< Complete warp-level accumulator tile + AccumulatorTile & fused_bias_act_accumlators) { ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + + + AccumulatorFragmentIterator accum_fragment_iterator(accumulators); + AccumulatorFragmentIterator fused_bias_act_fragment_iterator(fused_bias_act_accumlators); + + + + CUTLASS_PRAGMA_UNROLL + for (int iter = 0; iter < AccumulatorFragmentIterator::kIterations; ++iter) { + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + + accum_fragment_iterator.load(accum_fragment); + ++accum_fragment_iterator; + + typename AccumulatorFragmentIterator::Fragment fused_bias_act_fragment; + fused_bias_act_fragment = output_op(accum_fragment); + + fused_bias_act_fragment_iterator.store(fused_bias_act_fragment); + ++fused_bias_act_fragment_iterator; + } + } + +}; + + + + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/output_tile_thread_map_for_fused_bias.h b/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/output_tile_thread_map_for_fused_bias.h new file mode 100644 index 0000000000..c39e8ce199 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/threadblock/output_tile_thread_map_for_fused_bias.h @@ -0,0 +1,311 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Metaprogram for determining the mapping of output elements to threads for epilogue tiles. + + +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/fast_math.h" + +#include "cutlass/epilogue/threadblock/output_tile_thread_map.h" +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +/// RowArrangement determines how one or more warps cover a region of consecutive rows. +template < + typename Shape, + int WarpsRemaining, + int ElementsPerAccess, + int ElementSize, + bool Is2dTile +> +struct RowArrangementBiasAct; + +/// RowArrangement in which each warp's access is a 1D tiled arrangement. +template < + typename Shape, + int WarpsRemaining, + int ElementsPerAccess, + int ElementSize +> +struct RowArrangementBiasAct { + static int const kWarpSize = 32; + static int const kElementsPerAccess = ElementsPerAccess; + static int const kElementSize = ElementSize; + + static int const kIterationsRow = 1; + static int const kDeltaRow = 1; + static int const kIterationsColumn = Shape::kColumn / kElementsPerAccess / kWarpSize; + static int const kDeltaColumn = kWarpSize * kElementsPerAccess; + + static int const kAccessWidth = kWarpSize; + static int const kAccessRows = 1; + static int const kWarpPartitionsRow = 1; + static int const kWarpPartitionsColumn = WarpsRemaining; +}; + +/// RowArrangement in which each warp's access is a 2D tiled arrangement. 
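// Worked example for the 1D arrangement above (numbers are illustrative): with
// Shape::kColumn = 256, ElementsPerAccess = 8 and a 32-lane warp, each warp covers a
// single row with kIterationsColumn = 256 / 8 / 32 = 1 column iteration and
// kDeltaColumn = 32 * 8 = 256, i.e. one fully coalesced sweep across the row. The 2D
// specialization that follows instead splits the warp across several rows whenever a
// row is too narrow to occupy all 32 lanes at once.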
+template < + typename Shape, + int WarpsRemaining, + int ElementsPerAccess, + int ElementSize +> +struct RowArrangementBiasAct { + + static int const kMemoryAccessSize = 4;//128; + static int const kWarpSize = 32; + + static int const kElementsPerAccess = ElementsPerAccess; + static int const kElementSize = ElementSize; + + struct Detail { + static int const kShapeRow = Shape::kRow / WarpsRemaining; + static int const kShapeWidth = Shape::kColumn / kElementsPerAccess; + + static int const kTargetMemoryAccessWidth = + kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8); + + static int const kTargetAccessRows = kWarpSize / kTargetMemoryAccessWidth; + }; + + static int const kAccessWidth = + (Detail::kTargetAccessRows > Detail::kShapeRow ? + kWarpSize / Detail::kShapeRow + : const_min( + Detail::kShapeWidth, + const_min(kWarpSize, kMemoryAccessSize / (kElementsPerAccess * kElementSize / 8)) + )); + + static int const kAccessRows = + (Detail::kTargetAccessRows > Detail::kShapeRow ? + Detail::kShapeRow + : const_min(Shape::kRow, kWarpSize / kAccessWidth)); + + static int const kIterationsRow = Detail::kShapeRow / kAccessRows; + static int const kDeltaRow = kAccessRows; + + static int const kIterationsColumn = Detail::kShapeWidth / kAccessWidth; + static int const kDeltaColumn = kAccessWidth * kElementsPerAccess; + + static_assert( kAccessWidth * kElementsPerAccess <= Shape::kColumn, "Accessing too many elements per access"); + static_assert( kIterationsColumn > 0, "Iteration Count Column must be > 0" ); + static_assert( kIterationsRow > 0, "Iteration Count Row must be > 0" ); + + static int const kWarpPartitionsRow = 1; + static int const kWarpPartitionsColumn = 1; +}; + +} + +//////////////////////////////////////////////////////////////////////////////// + +/// Template metaprogram for partitioning a 4D space across warps to achieve several performance +/// objectives: +/// +/// - coalesced memory accesses in units of 16 Byte lines +/// - minimal address arithmetic +/// - minimal predicate calculations +/// +template < + typename Shape_, + typename Count_, + int Threads, + int ElementsPerAccess, + int ElementSize +> +struct OutputTileOptimalThreadMapBiasAct { + + using Shape = Shape_; + using Count = Count_; + + static int const kWarpSize = 32; + static int const kThreads = Threads; + static int const kWarpCount = kThreads / kWarpSize; + + static int const kElementsPerAccess = ElementsPerAccess; + static int const kElementSize = ElementSize; + + // + // Metaprogram computation + // + + struct Detail { + + // Clusters + static int const kIterationsCluster = + ((Shape::kCluster > kWarpCount) ? + Shape::kCluster / kWarpCount + : 1); + + static int const kDeltaCluster = + ((Shape::kCluster > kWarpCount) ? + Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup * Shape::kCluster / kIterationsCluster + : 1); + + static int const kCompactedDeltaCluster = + ((Shape::kCluster > kWarpCount) ? + Shape::kRow * Shape::kGroup * Shape::kCluster / kIterationsCluster + : 1); + + static int const kWarpPartitionsCluster = + ((Shape::kCluster > kWarpCount) ? + kWarpCount + : kWarpCount / Shape::kCluster); + + static int const kWarpsRemainingForGroups = + ((Shape::kCluster > kWarpCount) ? 1 : kWarpCount / Shape::kCluster); + + // Groups + static int const kIterationsGroup = + ((Shape::kGroup > kWarpsRemainingForGroups) ? + Shape::kGroup / kWarpsRemainingForGroups + : 1); + + static int const kDeltaGroup = + ((Shape::kGroup > kWarpsRemainingForGroups) ? 
+ Shape::kRow * Count::kRow * Shape::kGroup / kIterationsGroup + : 1); + + static int const kCompactedDeltaGroup = + ((Shape::kGroup > kWarpsRemainingForGroups) ? + Shape::kRow * Shape::kGroup / kIterationsGroup + : 1); + + static int const kWarpPartitionsGroup = + ((Shape::kGroup > kWarpsRemainingForGroups) ? + 1 + : kWarpsRemainingForGroups / Shape::kGroup); + + static int const kWarpsRemainingForRows = + ((Shape::kGroup > kWarpsRemainingForGroups) ? + 1 + : kWarpsRemainingForGroups / Shape::kGroup); + + // Rows + using RowArrangement = detail::RowArrangementBiasAct< + Shape, + kWarpsRemainingForRows, + kElementsPerAccess, + kElementSize, + (Shape::kRow > kWarpsRemainingForRows) + >; + + // Warp partitions + using WarpPartitions = OutputTileShape< + RowArrangement::kWarpPartitionsColumn, + RowArrangement::kWarpPartitionsRow, + kWarpPartitionsGroup, + kWarpPartitionsCluster, + 1>; + + static int const kAccessWidth = RowArrangement::kAccessWidth; + static int const kAccessRows = RowArrangement::kAccessRows; + }; + + // + // Output + // + + using Iterations = OutputTileShape< + Detail::RowArrangement::kIterationsColumn, + Detail::RowArrangement::kIterationsRow, + Detail::kIterationsGroup, + Detail::kIterationsCluster, + 1>; + + using Delta = OutputTileShape< + Detail::RowArrangement::kDeltaColumn, + Detail::RowArrangement::kDeltaRow, + Detail::kDeltaGroup, + Detail::kDeltaCluster, + 1>; + + /// Initial offset function + CUTLASS_HOST_DEVICE + static MatrixCoord initial_offset(int thread_idx) { + + int warp_idx = thread_idx / kWarpSize; + int lane_idx = thread_idx % kWarpSize; + + // Compute warp location + int cluster_idx = warp_idx / Detail::WarpPartitions::kCluster; + int residual_cluster = warp_idx % Detail::WarpPartitions::kCluster; + + int group_idx = residual_cluster / Detail::WarpPartitions::kGroup; + int residual_group = residual_cluster % Detail::WarpPartitions::kGroup; + + int row_idx = residual_group / Detail::WarpPartitions::kRow; + int col_idx = residual_group % Detail::WarpPartitions::kRow; + + // Compute per-lane offset + int lane_row_offset = lane_idx / Detail::kAccessWidth; + int lane_col_offset = lane_idx % Detail::kAccessWidth; + + // Compute coordinate in output space + int cluster_offset = cluster_idx * Shape::kRow * Count::kRow * Shape::kGroup * Count::kGroup; + int group_offset = group_idx * Shape::kRow * Count::kRow; + int row_offset = row_idx * Iterations::kRow * Detail::kAccessRows; + int column_offset = col_idx * Iterations::kColumn * Detail::kAccessWidth * kElementsPerAccess; + + return MatrixCoord( + cluster_offset + group_offset + row_offset + lane_row_offset, + (column_offset + lane_col_offset) * kElementsPerAccess + ); + } + +}; + + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass diff --git a/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/warp/fused_bias_act_fragment_iterator_tensor_op.h b/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/warp/fused_bias_act_fragment_iterator_tensor_op.h new file mode 100644 index 0000000000..cf12fef3b5 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/fixed_impl/epilogue/warp/fused_bias_act_fragment_iterator_tensor_op.h @@ -0,0 +1,189 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief This defines a "fragment" iterator for visiting the fragments of an accumulator tile + that participate in one warp-level store operation. + + Typically, the accumulator tile is the largest single block of register-backed storage + within the kernel. Storing it to memory is best accomplished by partitioning it into + smaller tiles and storing these sequentially. + + Round trips through shared memory during the Epilogue phase require partitioning, as + shared memory capacity is typically insufficient for a threadblock's total accumulator + size. 
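    A minimal sketch of the intended access pattern (iterator and type names are
    illustrative): construct the fragment iterator over a warp's accumulator tile and
    advance it once per store iteration,

        FusedBiasActFragmentIteratorTensorOp<...> frag_it(accumulators);
        for (int i = 0; i < decltype(frag_it)::kIterations; ++i, ++frag_it) {
          typename decltype(frag_it)::Fragment frag;
          frag_it.load(frag);
          // apply the output operator to frag, then frag_it.store(frag) writes it back
        }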
+*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/layout/matrix.h" + +#include "cutlass/epilogue/warp/tensor_op_policy.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace warp { + +//////////////////////////////////////////////////////////////////////////////// + +/// +template < + typename WarpShape, ///< shape of warp-level GEMM (concept: MatrixShape) + typename OperatorShape, ///< matrix multiply operation shape (concept: gemm::GemmShape) + typename OperatorElementC, ///< matrix multiply operation data type (concept: data type) + typename OperatorFragmentC, ///< matrix multiply operation fragment (concept: Array) + typename Layout ///< target shared memory layout +> +class FusedBiasActFragmentIteratorTensorOp; + +//////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization for row-major shared memory +template < + typename WarpShape_, ///< shape of the warp-level GEMM tile + typename OperatorShape_, ///< matrix multiply operation shape (concept: gemm::GemmShape) + typename OperatorElementC_, ///< matrix multiply operation data type (concept: data type) + typename OperatorFragmentC_ ///< matrix multiply operation fragment (concept: Array) +> +class FusedBiasActFragmentIteratorTensorOp { +public: + + using WarpShape = WarpShape_; + using OperatorShape = OperatorShape_; + using OperatorElementC = OperatorElementC_; + using OperatorFragmentC = OperatorFragmentC_; + using Layout = layout::RowMajor; + + using Policy = TensorOpPolicy; + + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array< + OperatorElementC, + Policy::OperatorCount::kColumn * Policy::kElementsPerAccess>; + + /// This is the complete warp-level accumulator tile. 
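// Example sizing (illustrative): with a 64x64 warp tile and a 16x8x8 tensor op
// instruction, Policy::OperatorCount::kColumn would be 64 / 8 = 8, so one Fragment
// spans the accumulator elements of all MMA columns visited by a single iteration of
// this iterator. The complete warp-level tile declared below aggregates every such
// fragment across all iterations.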
+ using AccumulatorTile = Array< + OperatorElementC, + OperatorFragmentC::kElements * Policy::OperatorCount::kRow * Policy::OperatorCount::kColumn>; + + using OutputAccumulatorTile = AccumulatorTile; + + /// Number of times this iterator can be incremented + static int const kIterations = Policy::kIterations; + +private: + + /// Internal access type + using AccessType = Array; + +private: + + // + // Data members + // + + /// Accumulator tile + AccessType *accumulators_; + + /// Internal index + int index_; + +public: + + /// Constructs an iterator + CUTLASS_HOST_DEVICE + FusedBiasActFragmentIteratorTensorOp(AccumulatorTile &accum): + accumulators_(reinterpret_cast(&accum)), + index_(0) { + } + + /// Increments + CUTLASS_HOST_DEVICE + FusedBiasActFragmentIteratorTensorOp &operator++() { + ++index_; + return *this; + } + + /// Decrements + CUTLASS_HOST_DEVICE + FusedBiasActFragmentIteratorTensorOp &operator--() { + --index_; + return *this; + } + + /// Loads a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void load(Fragment &frag, int index_offset = 0) const { + + int index = index_ + index_offset; + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { + + int accumulator_access_offset = + index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess; + + frag_ptr[n] = accumulators_[accumulator_access_offset]; + } + } + /// Stores a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void store(Fragment &frag, int index_offset = 0) const { + + int index = index_ + index_offset; + + AccessType *frag_ptr = reinterpret_cast(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Policy::OperatorCount::kColumn; ++n) { + + int accumulator_access_offset = + index + n * Policy::kAccumulatorColumnStride / Policy::kElementsPerAccess; + + accumulators_[accumulator_access_offset] = frag_ptr[n]; + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/44_multi_gemm_ir_and_codegen/fixed_impl/gemm/warp/mma_tensor_op_fragment_iterator_without_output_op.h b/examples/44_multi_gemm_ir_and_codegen/fixed_impl/gemm/warp/mma_tensor_op_fragment_iterator_without_output_op.h new file mode 100644 index 0000000000..0e89d6f875 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/fixed_impl/gemm/warp/mma_tensor_op_fragment_iterator_without_output_op.h @@ -0,0 +1,427 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/array.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_conversion.h" + +namespace cutlass { +namespace gemm { +namespace warp { + + +//////////////////////////////////////////////////////////////////////////////// + +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Size of the accumulation tile shape (concept: MatrixShape) + typename AccumulatorShape_, + /// KBlocks columns to compute residual + int KBlocksColumn_, + /// Accumulator Element type + typename ElementAccumulator_, + /// Element type + typename Element_, + /// Layout of operand in memory + typename Layout_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_, + /// Whether beta is zero + bool IsBetaZero_ > +class MmaTensorOpPureFragmentIterator; + + +// Partial specialization for col-major accumulator tile +// And Element type is the same as Accumulator Element type + +template < + /// Shape of warp tile to load (concept: MatrixShape) + typename Shape_, + /// Shape of the warp accumulation tile (concept: MatrixShape) + typename AccumulatorShape_, + /// KBlocks columns to compute residual + int KBlocksColumn_, + /// Element type + typename Element_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_> +class MmaTensorOpPureFragmentIterator { + public: + + /// Shape of warp tile to load (concept: MatrixShape) + using Shape = Shape_; + + /// Shape of the warp accumulation tile (concept: MatrixShape) + using AccumulatorShape = AccumulatorShape_; + + /// KBlocks columns to compute residual + static int const kKBlockColumn = KBlocksColumn_; + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = cutlass::layout::ColumnMajor; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + + /// Whether beta is zero + static bool const IsBetaZero = true; + + /// Number of participating threads + static int const kThreads = 32; + + /// Internal structure of iterator - made public to enable introspection + struct Policy { + static_assert( + !(Shape::kRow % InstructionShape::kM) && + !(Shape::kColumn % InstructionShape::kN), + "Shape of warp-level Mma must be divisible by operator 
shape."); + static_assert( + !(AccumulatorShape::kRow % Shape::kRow) && + !(AccumulatorShape::kColumn % Shape::kColumn), + "Shape of Warp Accumulator must be divisible by warp shape."); + static_assert( + !(kKBlockColumn % Shape::kColumn), + "KBlock size must be divisible by warp shape."); + + /// Number of times this iterator can be incremented + static int const kIterations = AccumulatorShape::kCount / Shape::kCount; + }; + +private: + + static int const kElementsPerAccess = InstructionShape::kM * InstructionShape::kN / kThreads; + + /// Number of mma operations performed by a warp + using MmaIterations = MatrixShape; + /// Number of mma operations performed by the entire accumulator + using AccumulatorIterations = MatrixShape; + + /// Number of K iterations + static int const kKBlockIterations = (AccumulatorShape::kColumn + kKBlockColumn - 1) / kKBlockColumn; + static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn; + static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn + * (AccumulatorShape::kRow / Shape::kRow); + static int const kResidualIndex = kResidualColumn / Shape::kColumn + * (AccumulatorShape::kRow / Shape::kRow); + +public: + + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + /// This is the fragment size produced by one access of the iterator. + using Fragment = Array; + + /// Accumulator Fragment object + using AccumulatorFragment = Array; + + +private: + + /// Internal access type + using AccessType = Array; + +private: + // + // Data members + // + + /// Accumulator tile + AccessType const *accumulators_; + + /// Internal index + int index_; + + /// Used to access residual tile first + bool is_residual_tile_; + +public: + /// Constructs an iterator + CUTLASS_HOST_DEVICE + MmaTensorOpPureFragmentIterator(AccumulatorFragment const &accum) + : accumulators_(reinterpret_cast(&accum)), + index_(0), is_residual_tile_(true) {} + + /// Add offset + CUTLASS_HOST_DEVICE + void add_offset(int index_offset) { + index_ += index_offset; + if(is_residual_tile_ && index_ >= kKBlockColumnIterations) { + index_ = index_ - kKBlockColumnIterations + kResidualIndex; + is_residual_tile_ = false; + } + } + + /// Increments + CUTLASS_HOST_DEVICE + MmaTensorOpPureFragmentIterator &operator++() { + add_offset(1); + return *this; + } + + /// Decrements + CUTLASS_HOST_DEVICE + MmaTensorOpPureFragmentIterator &operator--() { + add_offset(-1); + return *this; + } + + /// Loads a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + + AccessType src_fragment; + src_fragment.clear(); + + + AccessType *frag_ptr = reinterpret_cast(&frag); + + int index_m = (index_ * MmaIterations::kRow) % AccumulatorIterations::kRow; + int index_n = (index_ * MmaIterations::kRow) / AccumulatorIterations::kRow + * MmaIterations::kColumn; + + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < MmaIterations::kColumn; n++) { + for (int m = 0; m < MmaIterations::kRow; m++) { + int accumulator_access_offset = + (n + index_n) * AccumulatorIterations::kRow + m + index_m; + + frag_ptr[n * MmaIterations::kRow + m].clear(); + if(!(is_residual_tile_ && index_ >= kResidualIndex)) + frag_ptr[n * MmaIterations::kRow + m] = accumulators_[accumulator_access_offset]; + // frag_ptr[n * MmaIterations::kRow + m] = output_op(accumulators_[accumulator_access_offset], src_fragment); + } + } + } + +}; + +// Partial specialization for row-major accumulator tile + +template < + /// 
Shape of warp tile to load (concept: MatrixShape) + typename Shape_, + /// Shape of the warp accumulation tile (concept: MatrixShape) + typename AccumulatorShape_, + /// KBlocks columns to compute residual + int KBlocksColumn_, + /// Accumulator Element type + typename ElementAccumulator_, + /// Element type + typename Element_, + /// Shape of one matrix product operation (concept: MatrixShape) + typename InstructionShape_> +class MmaTensorOpPureFragmentIterator { + public: + + /// Shape of warp tile to load (concept: MatrixShape) + using Shape = Shape_; + + /// Shape of the warp accumulation tile (concept: MatrixShape) + using AccumulatorShape = AccumulatorShape_; + + /// KBlocks columns to compute residual + static int const kKBlockColumn = KBlocksColumn_; + + /// Accumulator Element type + using ElementAccumulator = ElementAccumulator_; + + /// Element type + using Element = Element_; + + /// Layout of source tile + using Layout = cutlass::layout::RowMajor; + + /// Shape of one matrix product operation (concept: MatrixShape) + using InstructionShape = InstructionShape_; + + /// Whether beta is zero + static bool const IsBetaZero = true; + + /// Number of participating threads + static int const kThreads = 32; + + /// Internal structure of iterator - made public to enable introspection + struct Policy { + static_assert( + !(Shape::kRow % InstructionShape::kM) && + !(Shape::kColumn % InstructionShape::kN), + "Shape of warp-level Mma must be divisible by operator shape."); + static_assert( + !(AccumulatorShape::kRow % Shape::kRow) && + !(AccumulatorShape::kColumn % Shape::kColumn), + "Shape of Warp Accumulator must be divisible by warp shape."); + static_assert( + !(kKBlockColumn % Shape::kColumn), + "KBlock size must be divisible by warp shape."); + + /// Number of times this iterator can be incremented + static int const kIterations = AccumulatorShape::kCount / Shape::kCount; + }; + +private: + + static int const kElementsPerAccess = InstructionShape::kM * InstructionShape::kN / kThreads; + + /// Number of mma operations performed by a warp + using MmaIterations = MatrixShape; + /// Number of mma operations performed by the entire accumulator + using AccumulatorIterations = MatrixShape; + + /// Number of K iterations + static int const kKBlockIterations = (AccumulatorShape::kColumn + kKBlockColumn - 1) / kKBlockColumn; + static int const kResidualColumn = AccumulatorShape::kColumn - (kKBlockIterations - 1) * kKBlockColumn; + static int const kKBlockColumnIterations = kKBlockColumn / Shape::kColumn + * (AccumulatorShape::kRow / Shape::kRow); + static int const kResidualIndex = kResidualColumn / Shape::kColumn + * (AccumulatorShape::kRow / Shape::kRow); + +public: + + // + // Derived quantities + // + + /// Fragment object holding a thread's part of a tile + /// This is the fragment size produced by one access of the iterator. 
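// Worked example of the residual-K bookkeeping above (shapes are illustrative): with
// AccumulatorShape = 64x96, Shape = 64x32 and kKBlockColumn = 64, the constants become
// kKBlockIterations = (96 + 63) / 64 = 2, kResidualColumn = 96 - 64 = 32,
// kKBlockColumnIterations = (64 / 32) * (64 / 64) = 2 and kResidualIndex = 1, so the
// iterator visits the 32-column residual K block first and then falls back to full
// 64-column K blocks.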
+ using Fragment = Array; + + /// Accumulator Fragment object + using AccumulatorFragment = Array; + + +private: + + /// Internal access type + using AccessType = Array; + using FragmentAccessType = Array; + +private: + // + // Data members + // + + /// Accumulator tile + AccessType const *accumulators_; + + /// Internal index + int index_; + + /// Used to access residual tile first + bool is_residual_tile_; + +public: + /// Constructs an iterator + CUTLASS_HOST_DEVICE + MmaTensorOpPureFragmentIterator(AccumulatorFragment const &accum) + : accumulators_(reinterpret_cast(&accum)), + index_(0), is_residual_tile_(true) {} + + /// Add offset + CUTLASS_HOST_DEVICE + void add_offset(int index_offset) { + index_ += index_offset; + if(is_residual_tile_ && index_ >= kKBlockColumnIterations) { + index_ = index_ - kKBlockColumnIterations + kResidualIndex; + is_residual_tile_ = false; + } + } + + /// Increments + CUTLASS_HOST_DEVICE + MmaTensorOpPureFragmentIterator &operator++() { + add_offset(1); + return *this; + } + + /// Decrements + CUTLASS_HOST_DEVICE + MmaTensorOpPureFragmentIterator &operator--() { + add_offset(-1); + return *this; + } + + /// Loads a fragment from the referenced part of the accumulator tile + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + + + FragmentAccessType src_fragment; + src_fragment.clear(); + + FragmentAccessType *frag_ptr = reinterpret_cast(&frag); + + int index_m = (index_ * MmaIterations::kRow) % AccumulatorIterations::kRow; + int index_n = (index_ * MmaIterations::kRow) / AccumulatorIterations::kRow + * MmaIterations::kColumn; + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < MmaIterations::kRow; m++) { + for (int n = 0; n < MmaIterations::kColumn; n++) { + int accumulator_access_offset = + (m + index_m) * AccumulatorIterations::kColumn + n + index_n; + + frag_ptr[m * MmaIterations::kColumn + n].clear(); + if(!(is_residual_tile_ && index_ >= kResidualIndex)) + frag_ptr[m * MmaIterations::kColumn + n] = (accumulators_[accumulator_access_offset]); + } + } + } + +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_all_code.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_all_code.py new file mode 100644 index 0000000000..6aef3bca9c --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_all_code.py @@ -0,0 +1,129 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import gen_turing_and_volta as api_generator +import gen_sample as sample_creater +import gen_cmake as cmake_creater +import gen_verify as verify_creater +import gen_device as b2b_fused_generator +import replace_fix_impl_header + +import argparse +import os +import json + + +parser = argparse.ArgumentParser(description="Generates Fused Multi-GEMM CUTLASS Kernels") +parser.add_argument("--config-file", default="config.json", help="JSON file containing configuration to generate") +parser.add_argument("--gen-name", default="FusedMultiGemmForward", help="Specific the output name") +parser.add_argument("--output-dir", default="", help="Specifies the output dir") +parser.add_argument("--cutlass-dir", default="", help="Specifies the dependent CUTLASS repo dir") +parser.add_argument("--gen-include-cutlass-dir", default="", help="Specifies the generated CUTLASS code include dir, if needed.") +args = parser.parse_args() + +gen_name = args.gen_name + +cutlass_deps_dir = args.cutlass_dir + +output_dir = args.output_dir +output_dir += "/" + +cutlass_deps_root = args.gen_include_cutlass_dir +if cutlass_deps_root == '': + cutlass_deps_root = cutlass_deps_dir + "/include/" +cutlass_deps_root +='/' + + +if not os.path.exists(output_dir): + os.makedirs(output_dir) + +if not os.path.exists(output_dir + "/" + "auto_gen"): + os.mkdir(output_dir + "/" + "auto_gen") + +if not os.path.exists(output_dir + "/" + "fixed_impl"): + os.mkdir(output_dir + "/" + "fixed_impl" ) + +if not os.path.exists(output_dir + "/" + "sample"): + os.mkdir(output_dir + "/" + "sample" ) + +if not os.path.exists(output_dir + "/" + "auto_gen" + "/" + "device"): + os.mkdir(output_dir + "/" + "auto_gen" + "/" + "device") +if not os.path.exists(output_dir + "/" + "auto_gen" + "/" + "kernel"): + os.mkdir(output_dir + "/" + "auto_gen" + "/" + "kernel") +if not os.path.exists(output_dir + "/" + "auto_gen" + "/" + "threadblock"): + os.mkdir(output_dir + "/" + "auto_gen" + "/" + "threadblock") + +with open(args.config_file, 'r') as infile: + gemm_info_dict = json.load(infile) + +keys = sorted(gemm_info_dict.keys()) +fuse_gemm_info = [gemm_info_dict[k] for k in keys] + + +for_cutlass_gen_user_include_header_file = [ + cutlass_deps_root + "cutlass/epilogue/thread/linear_combination_leaky_relu.h", + cutlass_deps_root + "cutlass/epilogue/thread/linear_combination.h", +] + +for_fused_wrapper = [ + cutlass_deps_root + "cutlass/epilogue/thread/linear_combination_leaky_relu.h", + cutlass_deps_root + "cutlass/epilogue/thread/linear_combination.h", + "auto_gen/device/" + gen_name + ".h", + cutlass_deps_root + "cutlass/gemm/device/gemm_batched.h", + cutlass_deps_root + 
"cutlass/cutlass.h", +] + +# Copy fixed implementation to the output directory +fix_impl = replace_fix_impl_header.replace_fix_impl("../fixed_impl/", output_dir +"/fixed_impl/", cutlass_deps_root) +fix_impl.gen_code() + +auto_gen_output_dir = output_dir + "/auto_gen/" +project_root = "" +turing_plus = b2b_fused_generator.gen_device(fuse_gemm_info, gen_name, for_cutlass_gen_user_include_header_file, cutlass_deps_root, project_root, auto_gen_output_dir) +turing_plus.gen_code(75, 'hmma1688', False) + +api = api_generator.gen_one_API(fuse_gemm_info, gen_name, for_fused_wrapper, output_dir) +api.gen_code() + +# Generate C++ sample +os.system("cp ../leaky_bias.h " + output_dir + "/sample/") +os.system("cp ../utils.h " + output_dir + "/sample/") + +sample_dir = output_dir + "/sample/" +sample = sample_creater.gen_test(fuse_gemm_info, gen_name, for_cutlass_gen_user_include_header_file, sample_dir) +sample.gen_cpp_sample() + +cmake_gen = cmake_creater.gen_build_sys(cutlass_deps_dir, output_dir) +cmake_gen.gen_code() + +verify = verify_creater.gen_verify(fuse_gemm_info, gen_name, for_fused_wrapper, output_dir) +verify.gen_code() diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_cmake.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_cmake.py new file mode 100644 index 0000000000..5db6dd6e07 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_cmake.py @@ -0,0 +1,131 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +class gen_build_sys: + def __init__(self, cutlass_deps_dir, output_dir = "../"): + self.output_dir = output_dir + self.cutlass_deps_dir = cutlass_deps_dir + + def gen_top(self): + code = "" + code += '''\ +# Auto Generated code - Do not edit. 
+ +cmake_minimum_required(VERSION 3.8) +project(CUTLASS_MULTI_GEMMS LANGUAGES CXX CUDA) +find_package(CUDAToolkit) +set(CUDA_PATH ${{CUDA_TOOLKIT_ROOT_DIR}}) +set(CUTLASS_PATH \"{cutlass_deps_dir}/include\") +set(CUTLASS_UTIL_PATH \"{cutlass_deps_dir}/tools/util/include\") +list(APPEND CMAKE_MODULE_PATH ${{CUDAToolkit_LIBRARY_DIR}}) +'''.format(cutlass_deps_dir=self.cutlass_deps_dir) + + code += '''\ +set(GPU_ARCHS \"\" CACHE STRING + \"List of GPU architectures (semicolon-separated) to be compiled for.\") + +if(\"${GPU_ARCHS}\" STREQUAL \"\") + set(GPU_ARCHS \"70\") +endif() + +foreach(arch ${GPU_ARCHS}) + set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -gencode arch=compute_${arch},code=sm_${arch}\") + if(SM STREQUAL 70 OR SM STREQUAL 75) + set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS} -DWMMA\") + set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -DWMMA\") + set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -DWMMA\") + endif() +endforeach() + +set(CMAKE_C_FLAGS \"${CMAKE_C_FLAGS}\") +set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS}\") +set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -Xcompiler -Wall\") + +set(CMAKE_C_FLAGS_DEBUG \"${CMAKE_C_FLAGS_DEBUG} -Wall -O0\") +set(CMAKE_CXX_FLAGS_DEBUG \"${CMAKE_CXX_FLAGS_DEBUG} -Wall -O0\") +set(CMAKE_CUDA_FLAGS_DEBUG \"${CMAKE_CUDA_FLAGS_DEBUG} -O0 -G -Xcompiler -Wall\") + +set(CMAKE_CXX_STANDARD 11) +set(CMAKE_CXX_STANDARD_REQUIRED ON) + +if(CMAKE_CXX_STANDARD STREQUAL \"11\") + set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} --expt-extended-lambda\") + set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} --expt-relaxed-constexpr\") +endif() + +set(CMAKE_CXX_FLAGS \"${CMAKE_CXX_FLAGS} -g -O3\") +set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -Xcompiler -O3\") +set(CMAKE_CUDA_FLAGS \"${CMAKE_CUDA_FLAGS} -Xcompiler=-fno-strict-aliasing\") + +set(COMMON_HEADER_DIRS + ${PROJECT_SOURCE_DIR} + ${CUDAToolkit_INCLUDE_DIRS} +) + +set(COMMON_LIB_DIRS + ${CUDAToolkit_LIBRARY_DIR} +) +list(APPEND COMMON_HEADER_DIRS ${CUTLASS_PATH}) +list(APPEND COMMON_HEADER_DIRS ${CUTLASS_UTIL_PATH}) +''' + code += '''\ +include_directories( + ${COMMON_HEADER_DIRS} +) + +link_directories( + ${COMMON_LIB_DIRS} +) + +add_definitions(-D_GLIBCXX_USE_CXX11_ABI=0) +add_definitions(-DGOOGLE_CUDA=1) + +add_executable(sample + sample/sample.cu + one_api.cu +) +target_link_libraries(sample PRIVATE + -lcudart + -lnvToolsExt + ${CMAKE_THREAD_LIBS_INIT} +) + +if(NOT DEFINED LIB_INSTALL_PATH) + set(LIB_INSTALL_PATH ${CMAKE_CURRENT_BINARY_DIR}) +endif() +''' + return code + + def gen_code(self): + top_code = self.gen_top() + with open(self.output_dir + "CMakeLists.txt", "w") as f: + f.write(top_code) diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_customized_epilogue.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_customized_epilogue.py new file mode 100644 index 0000000000..84621f2e79 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_customized_epilogue.py @@ -0,0 +1,120 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import ast + +fuse_gemm_info = [ + { + 'epilogue': { + 'tp': 'LeakyRelu', #'CustomizedLeaky_RELU' + 'bias': {'addbias': False, 'bias_tp': 'mat'}, + 'args': [('float', 'leaky_alpha', 1.3), ], + 'func': ''' +y = max(leaky_alpha * x, x) +y = y * x + ''' + } + }, + +] +class AnalysisNodeVisitor(ast.NodeVisitor): + def visit_Import(self,node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_ImportFrom(self,node): + ast.NodeVisitor.generic_visit(self, node) + + def visit_Assign(self,node): + print('Node type: Assign and fields: ', node._fields) + # print('Node type: Assign and targets value: ', node.targets, node.value) + + ast.NodeVisitor.generic_visit(self, node) + + def visit_BinOp(self, node): + print('Node type: BinOp and fields: ', node._fields) + print('node op: ', type(node.op).__name__) + ast.NodeVisitor.generic_visit(self, node) + + def visit_Expr(self, node): + print('Node type: Expr and fields: ', node._fields) + ast.NodeVisitor.generic_visit(self, node) + + def visit_Num(self,node): + print('Node type: Num and fields: ', node._fields) + print('Node type: Num: ', node.n) + + def visit_Name(self,node): + print('Node type: Name and fields: ', node._fields) + print('Node type: Name and fields: ', type(node.ctx).__name__, node.id) + + ast.NodeVisitor.generic_visit(self, node) + + def visit_Str(self, node): + print('Node type: Str and fields: ', node._fields) + +class CodeVisitor(ast.NodeVisitor): + def visit_BinOp(self, node): + if isinstance(node.op, ast.Add): + node.op = ast.Sub() + self.generic_visit(node) + + def visit_Assign(self, node): + print('Assign %s' % node.value) + self.generic_visit(node) + + def visit_Name(self, node): + print("Name:", node.id) + self.generic_visit(node) + + + def visit_FunctionDef(self, node): + print('Function Name:%s'% node.name.op) + self.generic_visit(node) + func_log_stmt = ast.Print( + dest = None, + values = [ast.Str(s = 'calling func: %s' % node.name, lineno = 0, col_offset = 0)], + nl = True, + lineno = 0, + col_offset = 0, + ) + node.body.insert(0, func_log_stmt) + +visitor = AnalysisNodeVisitor() + +code = \ +''' + +a=max(leaky_alpha * x, x +1) + +''' + +visitor.visit(ast.parse(code)) diff --git 
a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_device.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_device.py new file mode 100644 index 0000000000..371a4be847 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_device.py @@ -0,0 +1,469 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +from typing import * + +import helper +import gen_ir + +import gen_kernel as gen_ker + + +class gen_device: + def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, cutlass_deps_root, project_root, output_dir = "../"): + self.fuse_gemm_info = fuse_gemm_info + self.raw_gemm_info = fuse_gemm_info + self.b2b_num = len(fuse_gemm_info) + self.user_header_file = user_header_file + self.args = {} + # device arg struct memebr + self.arg_member = [] + self.gen_class_name = gen_class_name + self.gen_kernel_name = gen_class_name + "Kernel" + self.template_args = [] + self.__tempalate_arg_list = {'Stages': int, 'SplitKSerial': bool, 'IsBetaZero': bool, 'AlignmentA': int, 'AlignmentB': int} + + self.file_name = output_dir + "/device/" +gen_class_name +".h" + self.sample_dir = output_dir + + + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + self.this_file_root = output_dir + "/device/" + + self.first_use_1stage = False + + ## gen kernel + self.gen_kernel = gen_ker.gen_kernel(self.template_args, self.gen_class_name, self.b2b_num, output_dir, cutlass_deps_root, project_root) + + + def __check_arg_type(self, temp_arg): + if temp_arg in self.__tempalate_arg_list.keys(): + return self.__tempalate_arg_list[temp_arg] + + find_sub = False + for candidate_arg in self.__tempalate_arg_list.keys(): + if (temp_arg.find(candidate_arg) != -1): + return self.__tempalate_arg_list[candidate_arg] + + return 'typename' + + # def gen_B2b2bGemm_class(): + def set_arch(self, sm_cap, mma_tp): + if sm_cap == 75 or sm_cap == 80 or sm_cap == 86: + self.arch = "cutlass::arch::Sm" + str(sm_cap) + + if mma_tp is 'hmma1688': + self.mma_shape = [16, 8, 8] + self.mma_tp = 'hmma' + elif mma_tp is 'imma8816': + self.mma_tp = 'imma' + self.mma_shape = [8, 8, 16] + else: + return 0 + + def gen_include_header(self): + code = '''\ +/* Auto Generated code - Do not edit.*/ + +#pragma once + +#include \"{cutlass_root}cutlass/cutlass.h\" +#include \"{cutlass_root}cutlass/numeric_types.h\" +#include \"{cutlass_root}cutlass/arch/arch.h\" +#include \"{cutlass_root}cutlass/device_kernel.h\" + +#include \"{cutlass_root}cutlass/gemm/threadblock/threadblock_swizzle.h\" + +#include \"{cutlass_root}cutlass/gemm/device/default_gemm_configuration.h\" +#include \"{cutlass_root}cutlass/epilogue/thread/linear_combination_relu.h\" +#include \"{cutlass_root}cutlass/epilogue/thread/linear_combination.h\" + +#include \"{project_root}../kernel/b2b_gemm.h\" +#include \"{project_root}../kernel/default_b2b_gemm.h\" +'''.format(cutlass_root=self.cutlass_deps_root, project_root=self.project_root, this_file_root=self.this_file_root) + include_user_header = "" + for header in self.user_header_file: + include_user_header += "#include \"" + header + "\"\n" + return code + include_user_header + + def gen_code(self, sm_cap, mma_tp, ifprint = True): + self.set_arch(sm_cap, mma_tp) + + self.update_b2b_args() + print(self.fuse_gemm_info) + self.update_b2b_class_template_args() + + func_code = self.gen_all_func() + member_var_code = "private:\n typename B2bGemmKernel::Params params_;\n" + + gen_code = gen_ir.gen_template_class(self.gen_class_name, self.template_args, func_code + member_var_code) + code = self.gen_include_header() + gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("device", gen_code))) + + if ifprint: + print(code) + + print("[INFO]: Gen device code output Dir: is ", self.file_name) 
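        # The assembled device-level header is written to self.file_name below; the matching
        # kernel-level code is then produced by self.gen_kernel (see gen_kernel.gen_code).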
+ with open(self.file_name, 'w+') as f: + f.write(code) + + + gen_kernel = self.gen_kernel.gen_code(self.first_use_1stage) + print(gen_kernel) + + def update_b2b_class_template_args(self): + for arg in self.args.keys(): + self.template_args.append([self.__check_arg_type(arg), arg, self.args[arg]]) + + def update_b2b_args(self): + + self.args['ElementA'] = helper.type_2_cutlass_type(self.fuse_gemm_info[0]['A_tp']) + self.args['LayoutA'] = helper.type_2_cutlass_type(self.fuse_gemm_info[0]['A_format']) + + cnt = 0 + + warp_M_tile = 32 + + # Determine maxmimum N_tile + Max_Ntile = 0 + for layer in self.fuse_gemm_info: + n_tile = layer['mnk'][1] + if n_tile > Max_Ntile: + Max_Ntile = n_tile + if Max_Ntile >= 256: + warp_M_tile = 16 + + stages_temp = [] + + for layer in self.fuse_gemm_info: + cnt_str = str(cnt) + B_tp_str= 'ElementB' + cnt_str + B_format_str = 'LayoutB' + cnt_str + C_tp_str= 'ElementC' + cnt_str + C_format_str = 'LayoutC' + cnt_str + Acc_str = 'ElementAccumulator' + cnt_str + + self.args[B_tp_str] = helper.type_2_cutlass_type(layer['B_tp']) + self.args[B_format_str] = helper.type_2_cutlass_type(layer['B_format']) + self.args[C_tp_str] = helper.type_2_cutlass_type(layer['C_tp']) + self.args[C_format_str] = helper.type_2_cutlass_type(layer['C_format']) + self.args[Acc_str] = helper.type_2_cutlass_type(layer['Acc_tp']) + + + mnk = layer['mnk'][:] + + tile_mnk = mnk[:] + + tile_mnk[2] = 32 # force the ktile is 32 + + #N tile gen + if mnk[1] > 1024: + assert(0) + elif mnk[1] > 512: + tile_mnk[1] = 1024 + elif mnk[1] > 256: + tile_mnk[1] = 512 + elif mnk[1] > 128: + tile_mnk[1] = 256 + elif mnk[1] > 64: + tile_mnk[1] = 128 + elif mnk[1] > 32: + tile_mnk[1] = 64 + else : + tile_mnk[1] = 32 + + if tile_mnk[1] == 512: + stages_temp.append(1) + else: + stages_temp.append(2) + + tile_mnk[0] = 4 * warp_M_tile + + + + epilogue_setted_type = helper.get_epilogue_tp(layer) + cutlass_epilogue_name = "LinearCombinationRelu" + if epilogue_setted_type.lower() == 'leakyrelu': + cutlass_epilogue_name = "LinearCombinationLeakyRelu" + elif epilogue_setted_type.lower() == 'identity': + cutlass_epilogue_name = "LinearCombination" + + epilogue_str = 'EpilogueOutputOp' + cnt_str + if cnt != len(self.fuse_gemm_info) - 1: + n = layer['mnk'][1] + Fragments = tile_mnk[1] // 8 * 2 + self.args[epilogue_str] = "cutlass::epilogue::thread::" + cutlass_epilogue_name + "" + else: + n = layer['mnk'][1] + n_mod_8 = n % 4 + N_align_elements = 1 + if n_mod_8 == 0: + N_align_elements = 8 + elif n_mod_8 == 4: + N_align_elements = 4 + elif n_mod_8 == 2 or n_mod_8 == 6: + N_align_elements = 2 + + self.args[epilogue_str] = "cutlass::epilogue::thread::" + cutlass_epilogue_name+ "" + + + + ThreadBlockShape_str = 'ThreadblockShape' + cnt_str + + self.args[ThreadBlockShape_str] = helper.cvt_2_cutlass_shape(tile_mnk) + + WarpShape_str = 'WarpShape' + cnt_str + tile_mnk[0] = warp_M_tile + self.args[WarpShape_str] = helper.cvt_2_cutlass_shape(tile_mnk) + cnt += 1 + + + self.args['ElementD'] = helper.type_2_cutlass_type(self.fuse_gemm_info[self.b2b_num - 1]['C_tp']) + self.args['LayoutD'] = helper.type_2_cutlass_type(self.fuse_gemm_info[self.b2b_num - 1]['C_format']) + + self.args['InstructionShape'] = helper.cvt_2_cutlass_shape(self.mma_shape) + self.args['OperatorClass'] = 'arch::OpClassTensorOp' + self.args['ArchTag'] = self.arch + self.args['ThreadblockSwizzle'] = 'threadblock::GemmBatchedIdentityThreadblockSwizzle' + + + for i in range(self.b2b_num): + self.args[helper.var_idx('Stages', i)] = "2" + + self.args['AlignmentA'] = 
str(8) + self.args['AlignmentB'] = str(8) + self.args['SplitKSerial'] = 'false' + self.args['Operator'] = 'typename DefaultGemmConfiguration::Operator' + self.args['IsBetaZero'] = 'false' + + + def gen_using_kernel(self): + code = "using B2bGemmKernel = typename kernel::DefaultB2bGemm<\n" + code += " " + "ElementA,\n" + code += " " + "LayoutA,\n" + + for i in range(self.b2b_num): + code += " " + helper.var_idx("ElementB", i) + ",\n" + code += " " + helper.var_idx("LayoutB", i) + ",\n" + code += " " + helper.var_idx("ElementC", i) + ",\n" + code += " " + helper.var_idx("LayoutC", i) + ",\n" + code += " " + helper.var_idx("ElementAccumulator", i) + ",\n" + code += " " + helper.var_idx("EpilogueOutputOp", i) + ",\n" + code += " " + helper.var_idx("ThreadblockShape", i) + ",\n" + code += " " + helper.var_idx("WarpShape", i) + ",\n" + + code += " " + "ElementD,\n" + code += " " + "LayoutD,\n" + code += " " + "InstructionShape,\n" + code += " " + "OperatorClass,\n" + code += " " + "ArchTag,\n" + code += " " + "ThreadblockSwizzle,\n" + + for i in range(self.b2b_num): + code += " " + helper.var_idx("Stages", i) + ",\n" + + + code += " " + "AlignmentA,\n" + code += " " + "AlignmentB,\n" + code += " " + "SplitKSerial,\n" + code += " " + "Operator,\n" + code += " " + "IsBetaZero_\n" + + code += ">::B2bGemmKernel;\n\n" + + return code + + def gen_args(self): + + def gen_arg_member(b2b_num): + data_members = [] + + for i in range(b2b_num): + member_type = "GemmCoord" + member_name = "problem_size_" + str(i) + data_members.append((member_type, member_name)) + + member_type = "TensorRef" + member_name = "ref_A0" + data_members.append((member_type, member_name)) + + for i in range(b2b_num): + member_type = "TensorRef" + member_name = "ref_B" + str(i) + data_members.append((member_type, member_name)) + member_type = "TensorRef" + member_name = "ref_C" + str(i) + data_members.append((member_type, member_name)) + + member_type = "TensorRef" + member_name = helper.var_idx("ref_D", b2b_num - 1) + data_members.append((member_type, member_name)) + + for i in range(b2b_num): + member_type = "typename EpilogueOutputOp" + str(i) + "::Params" + member_name = "epilogue" + str(i) + data_members.append((member_type, member_name)) + + data_members.append(('int', 'batch_count')) + + return data_members + + def gen_arg_struct_default_ctor(struct_name, data_members, inital_param_num, inital_value): + constructs_code = gen_ir.indentation + "CUTLASS_HOST_DEVICE\n" + \ + gen_ir.indentation + struct_name + " (): " + for i in range(inital_param_num): + final_param = ',' + if i == inital_param_num - 1: + final_param = '{ }' + constructs_code += data_members[i][1] + inital_value + final_param + + constructs_code += "\n" + return constructs_code + + def gen_arg_struct_ctor(struct_name, data_members): + constructs_code = gen_ir.indentation + "CUTLASS_HOST_DEVICE\n" + \ + gen_ir.indentation + struct_name + " (\n" + cnt = 0 + param_num = len(data_members) + for param in data_members: + final = ',\n' + if cnt == param_num - 1: + final = '\n):\n' + constructs_code += gen_ir.indentation + param[0] + " " + param[1] + "_" + final + cnt += 1 + + cnt = 0 + for param in data_members: + final = '),\n' + if cnt == param_num - 1: + final = ") { }\n" + constructs_code += gen_ir.indentation + param[1] + "(" + param[1] + "_" + final + cnt += 1 + + constructs_code += "\n" + return constructs_code + + # (variable type, variable name) + struct_member = gen_arg_member(self.b2b_num) + self.arg_member = struct_member + + codeBody = "" + for each_member 
in struct_member: + codeBody += gen_ir.indentation + each_member[0] + " " + each_member[1] + ";\n" + + codeBody += gen_arg_struct_default_ctor("Arguments", struct_member, self.b2b_num, "(0,0,0)") + "\n" + codeBody += gen_arg_struct_ctor("Arguments", struct_member) + "\n" + struct_code = gen_ir.gen_struct("Arguments", codeBody) + return struct_code + + def gen_func_constructs(self): + code = self.gen_class_name +"() {}" + return code + + def gen_func_initialize(self): + code = "Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) {\n" + \ + "// Determine grid shape\n" + \ + "ThreadblockSwizzle threadblock_swizzle;\n" + \ + "cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape(\n" + \ + " args.problem_size_0, \n" + \ + " { ThreadblockShape0::kM, ThreadblockShape0::kN, ThreadblockShape0::kK },\n" + \ + " args.batch_count);\n" + \ + "// Initialize the Params structure\n" + \ + "params_ = typename B2bGemmKernel::Params{\n" + for i in range(self.b2b_num): + code += helper.var_idx(" args.problem_size_", i) + ",\n" + code += " grid_shape,\n" + \ + " args.ref_A0.non_const_ref(),\n" + for i in range(self.b2b_num): + code += helper.var_idx(" args.ref_B", i) + ".non_const_ref(),\n" + code += helper.var_idx(" args.ref_C", i) + ".non_const_ref(),\n" + + code += helper.var_idx(" args.ref_D", self.b2b_num - 1) + ",\n" + for i in range(self.b2b_num): + code += helper.var_idx(" args.epilogue", i) + ",\n" + + code += " args.batch_count\n" + code += "};\n" + \ + "return Status::kSuccess;\n" + \ + "}\n" + return code + + def gen_func_run(self): + code = "Status run(cudaStream_t stream = nullptr) {\n" + \ + "\n" + \ + " ThreadblockSwizzle threadblock_swizzle;\n" + \ + "\n" + \ + " dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape);\n" + \ + " dim3 block(B2bGemmKernel::kThreadCount, 1, 1);\n" + \ + "\n" + \ + " cudaError_t result;\n" + \ + "\n" + \ + " int smem_size = int(sizeof(typename B2bGemmKernel::SharedStorage));\n" + \ + " if (smem_size >= (48 << 10)) {\n" + \ + " result = cudaFuncSetAttribute(Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size);\n" + \ + "\n" + \ + " if (result != cudaSuccess) {\n" + \ + " return Status::kErrorInternal;\n" + \ + " }\n" + \ + " }\n" + \ + " cutlass::Kernel<<>>(params_);\n" + \ + " result = cudaGetLastError();\n" + \ + " return result == cudaSuccess ? 
Status::kSuccess : Status::kErrorInternal;\n" + \ + " }\n" + + return code + def gen_func_operator(self): + opeartor_with_arg_code = "Status operator()(\n" + \ + " Arguments const &args,\n" + \ + " void *workspace = nullptr,\n" + \ + " cudaStream_t stream = nullptr) {\n" + \ + " Status status = initialize(args, workspace);\n" + \ + " \n" + \ + " if (status == Status::kSuccess) {\n" + \ + " status = run(stream);\n" + \ + " }\n" + \ + " return status;\n" + \ + "}\n" + operator_code = "Status operator()(\n" + \ + " cudaStream_t stream = nullptr) {\n" + \ + " Status status = run(stream);\n" + \ + " return status;\n" + \ + "}\n" + return opeartor_with_arg_code + "\n" + operator_code + + def gen_all_func(self): + return self.gen_using_kernel() + "\n" + \ + self.gen_args() + "\n" + \ + self.gen_func_constructs() + "\n" + \ + self.gen_func_initialize() + "\n" + \ + self.gen_func_run() + "\n" + \ + self.gen_func_operator() diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_ir.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_ir.py new file mode 100644 index 0000000000..919c777e42 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_ir.py @@ -0,0 +1,249 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +import helper + + +indentation = " " + + +def append_word(word): + code = "" + code += word + code += " " + return code + + +def gen_namespace(namespace, codeBody): + code_gen = "namespace " + namespace + " {\n" + code_gen += codeBody + code_gen += "} // namespace " + namespace + "\n" + return code_gen + + +def gen_expression(type, lval, rval = None): + code_gen = "" + code_gen += append_word(type) + code_gen += append_word(lval) + if rval is not None: + code_gen += append_word("=") + code_gen += append_word(rval) + return code_gen + + +def gen_class(name, codeBody, inheritance_code = None): + code_gen = "" + if inheritance_code is None: + code_gen = "class " + name + "{\n" + else: + code_gen = "class " + name + " : "+ inheritance_code + "{\n" + code_gen += codeBody + code_gen += "}; // class " + name + "\n" + return code_gen + + +def gen_struct(name, codeBody, specialized = None): + specialized_code = "" + if specialized is not None: + specialized_code = "<" + specialized + ">" + code_gen = "struct " + name + specialized_code + "{\n" + code_gen += codeBody + code_gen += "}; // struct " + name + "\n" + return code_gen + + +def gen_template_arg(arg_type, arg_name, default_val = None): + rval = None + if default_val is not None: + rval = str(default_val) + + arg_typename = "" + if arg_type is int: + arg_typename = "int" + elif arg_type is bool: + arg_typename = "bool" + else: + arg_typename = "typename" + + internal_arg_name = arg_name + "_" + + code_gen = indentation + code_gen += gen_expression(arg_typename, internal_arg_name, rval) + + return code_gen + + +def gen_template_args(args, set_default = True): + arg_len = len(args) + cnt = 1 + code_gen = "" + for arg_tuple in args: + arg_type = arg_tuple[0] + arg_name = arg_tuple[1] + arg_default_val = None + if len(arg_tuple) == 3 and set_default: + arg_default_val = arg_tuple[2] + + code_gen += gen_template_arg(arg_type, arg_name, arg_default_val) + if cnt != arg_len: + code_gen += ",\n" + cnt += 1 + + return code_gen + + +def gen_template_head(args, set_default = True): + code_gen = "template <\n" + code_gen += gen_template_args(args, set_default) + code_gen += ">\n" + return code_gen + + +def export_template_args(args): + code_gen = "public:\n" + for arg_tuple in args: + code_gen += indentation + arg_type = arg_tuple[0] + arg_name = arg_tuple[1] + internal_arg_name = arg_name + "_" + + typename = "" + if arg_type is int: + typename = "static int const" + elif arg_type is bool: + typename = "static bool const" + else: + typename = "using" + + code_gen += gen_expression(typename, arg_name, internal_arg_name) + code_gen += ";\n" + return code_gen + + +def gen_template_class(class_name, args, codeBody, set_default = True, inheritance_code = None): + code_gen = "" + + code_gen += gen_template_head(args, set_default) + code_gen += gen_class(class_name, export_template_args(args) + codeBody, inheritance_code) + + return code_gen + + +def gen_template_struct(struct_name, args, codeBody, speicalized = None, set_default = True, export_args = True): + code_gen = "" + code_gen += gen_template_head(args, set_default) + code = export_template_args(args) + codeBody + if export_args is False: + code = codeBody + code_gen += gen_struct(struct_name, code , speicalized) + + return code_gen + + +def gen_declare_template_struct(name, *params): + code = name + "<" + cnt = 0 + param_num = len(params) + for param in params: + final = ", " + if cnt == param_num 
- 1: + final = "" + code += param + final + cnt += 1 + code += ">;\n" + return code + + +def filtered_param(params, name_and_value_pair, keep_ = False): + rtn_template_args = [] + speicalized_template_args = [] + + for param in params: + param_name = "" + if len(param) >= 1: + param_name = param[1] + else: + param_name = param[0] + + hit_flag = False + set_value = "" + for n_v_pair in name_and_value_pair: + + filter_name = n_v_pair[0] + set_value = n_v_pair[1] + + if param_name == (filter_name + "_") or param_name == filter_name : + hit_flag = True + break + + + if hit_flag is False: + rtn_template_args.append(param) + + if hit_flag is True: + speicalized_template_args.append(set_value) + else: + if keep_ is True: + speicalized_template_args.append(param_name + "_") + else: + speicalized_template_args.append(param_name) + + + specialized_template_arg_str = helper.list_2_string(speicalized_template_args) + + return rtn_template_args, specialized_template_arg_str + + +def gen_func(func_name, arg_lists, code_body, only_declare = False, with_cudaStream = True): + code = "void " + func_name + "(\n" + for arg in arg_lists: + arg_tp = arg[0] + arg_nm = arg[1] + code += " " + arg_tp + " " + arg_nm + ",\n" + code += "cudaStream_t stream)" + if only_declare : + return code + code += "{\n" + + code += code_body + "\n" + code += "}\n" + return code + + +def indent_level(code, level = 0): + rtn_code = "" + for i in range(level): + rtn_code += " " + + rtn_code += code + + return rtn_code diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_kernel.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_kernel.py new file mode 100644 index 0000000000..2bbaf26b40 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_kernel.py @@ -0,0 +1,476 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +import gen_ir +import helper +import gen_threadblock as gen_tb + + +class gen_default_Gemm: + def __init__(self, template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root): + self.gen_class_name = "B2bGemm" + self.template_param = template_param + self.b2b_num = b2b_num + + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + + def gen_B2bMma(self, specialized_template_args): + code = "using B2bMma = typename cutlass::gemm::threadblock::DefaultB2bMma<\n" + code += specialized_template_args + code += ">::ThreadblockB2bMma;\n" + + # print(code) + return code + + def gen_epilogue(self): + epilogue_code = "" + epilogue_code += helper.var_idx("static const int kPartitionsK", self.b2b_num - 1) + helper.var_idx(" = ThreadblockShape", self.b2b_num - 1) + helper.var_idx("::kK / WarpShape", self.b2b_num - 1) + "::kK;\n" + + epilogue_code += "using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp<\n" + epilogue_code += " " + helper.var_idx("ThreadblockShape", self.b2b_num - 1) + ",\n" + epilogue_code += " " + helper.var_idx("typename B2bMma::Operator", self.b2b_num - 1) + ",\n" + epilogue_code += " " + helper.var_idx("kPartitionsK", self.b2b_num - 1) + ",\n" + epilogue_code += " " + helper.var_idx("EpilogueOutputOp", self.b2b_num - 1) + ",\n" + epilogue_code += " " + helper.var_idx("EpilogueOutputOp", self.b2b_num - 1) + "::kCount\n" + epilogue_code += ">::Epilogue;\n" + + epilogue_code += "using B2bGemmKernel = kernel::B2bGemm;\n\n" + + return epilogue_code + + + def gen_include_header(self): + code = ''' +/* Auto Generated code - Do not edit.*/ + +#pragma once +#include \"{cutlass_dir}cutlass/cutlass.h\" + +#include \"{cutlass_dir}cutlass/layout/matrix.h\" +#include \"{cutlass_dir}cutlass/numeric_types.h\" + +#include \"{cutlass_dir}cutlass/epilogue/threadblock/epilogue.h\" +#include \"{cutlass_dir}cutlass/epilogue/thread/linear_combination.h\" + +#include \"{cutlass_dir}cutlass/gemm/gemm.h\" +#include \"{cutlass_dir}cutlass/gemm/kernel/gemm_pipelined.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm75.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm70.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm80.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_simt.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/threadblock_swizzle.h\" +#include \"{cutlass_dir}cutlass/epilogue/threadblock/default_epilogue_tensor_op.h\" +#include \"{cutlass_dir}cutlass/epilogue/threadblock/default_epilogue_volta_tensor_op.h\" +#include \"{cutlass_dir}cutlass/epilogue/threadblock/default_epilogue_simt.h\" + +#include \"{cutlass_dir}cutlass/transform/threadblock/predicated_tile_iterator.h\" + +#include \"../kernel/b2b_gemm.h\" +#include \"../threadblock/default_b2b_mma.h\" +'''.format(cutlass_dir=self.cutlass_deps_root) + return code + + def gen_code(self): + gen_using = '' + # Generate default template struct + gen_code = gen_ir.gen_template_struct("Default" + self.gen_class_name, self.template_param,"", speicalized = None, set_default=False) + + + filter_list = [] + filter_list.append(('Stages', 2)) + filter_list.append(("OperatorClass", "arch::OpClassTensorOp")) + filter_list.append(("ArchTag", "arch::Sm75")) + + for i in range(self.b2b_num): + filter_list.append((helper.var_idx("LayoutC", i), "layout::RowMajor")) + + + rtn_template_args, 
speicalized_template_args = gen_ir.filtered_param(self.template_param, filter_list, keep_= True) + + + B2bMma_code = self.gen_B2bMma(speicalized_template_args) + epilogue_and_rest_code = self.gen_epilogue() + + gen_special_code = gen_ir.gen_template_struct("Default" + self.gen_class_name, rtn_template_args, B2bMma_code + epilogue_and_rest_code, speicalized = speicalized_template_args, set_default=False) + + code = gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("kernel", gen_code + gen_special_code))) + + return self.gen_include_header() + code + + +class gen_Kernel: + def __init__(self, template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root): + self.gen_class_name = "B2bGemm" + self.template_param = template_param + self.b2bnum = b2b_num + + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + + def gen_include_header(self): + code = ''' +#pragma once + +#include \"{cutlass_dir}cutlass/cutlass.h\" +#include \"{cutlass_dir}cutlass/gemm/gemm.h\" +#include \"{cutlass_dir}cutlass/matrix_coord.h\"\n'''.format(cutlass_dir=self.cutlass_deps_root) + return code + + def gen_Params(self): + gen_param = "" + for i in range(self.b2bnum): + gen_param += " " + helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + ";\n" + gen_param += " " + "cutlass::gemm::GemmCoord grid_tiled_shape;\n" + gen_param += " " + "typename B2bMma::IteratorA0::Params params_A0;\n" + gen_param += " " + "typename B2bMma::IteratorA0::TensorRef ref_A0;\n" + + for i in range(self.b2bnum): + gen_param += " " + helper.var_idx("typename B2bMma::IteratorB", i) + helper.var_idx("::Params params_B", i) + ";\n" + gen_param += " " + helper.var_idx("typename B2bMma::IteratorB", i) + helper.var_idx("::TensorRef ref_B", i) + ";\n" + if i == self.b2bnum - 1: + gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::Params params_C", i) + ";\n" + gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_C", i) + ";\n" + + else: + gen_param += " " + helper.var_idx("typename FusedAddBiasEpilogue", i) + helper.var_idx("::OutputTileIterator::Params params_C", i) + ";\n" + gen_param += " " + helper.var_idx("typename FusedAddBiasEpilogue", i) + helper.var_idx("::OutputTileIterator::TensorRef ref_C", i) + ";\n" + + + + + gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::Params params_D", self.b2bnum - 1) + ";\n" + gen_param += " " + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_D", self.b2bnum - 1) + ";\n" + + for i in range(self.b2bnum): + gen_param += " " + helper.var_idx("typename OutputOp", i) + helper.var_idx("::Params output_op_", i) + ";\n" + + gen_param += " " + 'int batch_count' + ";\n" + gen_param += " " + 'int gemm_k_iterations_0' + ";\n" + + + return gen_param + + def gen_Memberfunc(self): + code_default = "\nCUTLASS_HOST_DEVICE\n" + code_default += "Params()" + + code_default += " { } \n\n" + + code_construct = "\nCUTLASS_HOST_DEVICE\n" + code_construct += "Params(\n" + + for i in range(self.b2bnum): + code_construct += " " + helper.var_idx("cutlass::gemm::GemmCoord const & problem_size_", i) + ",\n" + + code_construct += " " + "cutlass::gemm::GemmCoord const & grid_tiled_shape,\n" + + code_construct += " " + "typename B2bMma::IteratorA0::TensorRef ref_A0,\n" + + for i in range(self.b2bnum): + code_construct += " " + helper.var_idx("typename B2bMma::IteratorB", i) + helper.var_idx("::TensorRef ref_B", i) + ",\n" + if i == self.b2bnum - 1: + code_construct += " 
" + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_C", i) + ",\n" + else: + code_construct += " " + helper.var_idx("typename FusedAddBiasEpilogue", i) + helper.var_idx("::OutputTileIterator::TensorRef ref_C", i) + ",\n" + + code_construct += " " + helper.var_idx("typename Epilogue::OutputTileIterator::TensorRef ref_D", self.b2bnum - 1) + ",\n" + for i in range(self.b2bnum): + code_construct += " " + helper.var_idx("typename OutputOp", i) + helper.var_idx("::Params output_op_", i) + helper.var_idx(" = typename OutputOp", i) + "::Params(),\n" + + code_construct += " " + "int batch_count = 1\n" + + code_construct += "):\n" + + for i in range(self.b2bnum): + code_construct += " " + helper.var_idx("problem_size_", i) + helper.var_idx("(problem_size_", i) + "),\n" + + code_construct += " " + "grid_tiled_shape(grid_tiled_shape),\n" + code_construct += " " + "params_A0(ref_A0.layout()),\n" + code_construct += " " + "ref_A0(ref_A0),\n" + + for i in range(self.b2bnum): + code_construct += " " + helper.var_idx("params_B", i) + helper.var_idx("(ref_B", i) + ".layout()),\n" + code_construct += " " + helper.var_idx("ref_B", i) + helper.var_idx("(ref_B", i) + "),\n" + code_construct += " " + helper.var_idx("params_C", i) + helper.var_idx("(ref_C", i) + ".layout()),\n" + code_construct += " " + helper.var_idx("ref_C", i) + helper.var_idx("(ref_C", i) + "),\n" + + code_construct += " " + helper.var_idx("params_D", self.b2bnum - 1) + helper.var_idx("(ref_D", self.b2bnum - 1) + ".layout()),\n" + code_construct += " " + helper.var_idx("ref_D", self.b2bnum - 1) + helper.var_idx("(ref_D", self.b2bnum - 1) + "),\n" + + for i in range(self.b2bnum): + code_construct += " " + helper.var_idx("output_op_", i) + helper.var_idx("(output_op_", i) + "), \n" + + code_construct += " " + "batch_count(batch_count) {\n" + code_construct += " " + helper.var_idx("gemm_k_iterations_", 0) + helper.var_idx(" = (problem_size_", 0) + helper.var_idx(".k() + B2bMma::Shape", 0) + helper.var_idx("::kK - 1) / B2bMma::Shape", 0) + "::kK;\n" + + code_construct += "}\n" + + return code_default + code_construct + + def gen_using(self): + code_using = "" + + for i in range(self.b2bnum - 1): + code_using += " " + helper.var_idx("using OutputOp", i) + helper.var_idx(" = typename B2bMma::OutputOp", i) + ";\n" + + code_using += " " + helper.var_idx("using OutputOp", self.b2bnum - 1) + " = typename Epilogue::OutputOp;\n" + + for i in range(self.b2bnum - 1): + code_using += " " + helper.var_idx("using FusedAddBiasEpilogue", i) + helper.var_idx(" = typename B2bMma::FusedAddBiasEpilogue", i) +";\n" + + + code_using += " " + "using WarpCount0 = typename B2bMma::WarpCount0;\n" + code_using += " " + "static int const kThreadCount = 32 * WarpCount0::kCount;\n" + + code_using += gen_ir.gen_struct("Params", self.gen_Params() + self.gen_Memberfunc()) + + code_using += "union SharedStorage {\n" + code_using += " " + "typename B2bMma::B2bMmaSharedStorage main_loop;\n" + code_using += " " + "typename Epilogue::SharedStorage epilogue;\n" + code_using += "};\n" + + return code_using + + def gen_can_implement(self): + gen_code = "" + return gen_code + + def gen_operator_and_constr(self): + ctr_code = "CUTLASS_HOST_DEVICE\n" + ctr_code += self.gen_class_name + "() { } \n\n" + operator_code = "CUTLASS_DEVICE\n" + operator_code += "void operator()(Params const ¶ms, SharedStorage &shared_storage) {\n" + operator_code += " " + "ThreadblockSwizzle threadblock_swizzle;\n" + operator_code += " " + "cutlass::gemm::GemmCoord threadblock_tile_offset = 
threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);\n" + operator_code += " " + "int batch_idx = threadblock_tile_offset.k();\n" + operator_code += " " + "if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() ||\n" + operator_code += " " + "params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) {\n" + operator_code += " " + " " + "return;\n" + operator_code += " " + "}\n" + + operator_code += " " + "cutlass::MatrixCoord tb_offset_A0{\n" + operator_code += " " + " " + "threadblock_tile_offset.m() * B2bMma::Shape0::kM,\n" + operator_code += " " + " " + "0\n" + operator_code += " " + "};\n" + + for i in range(self.b2bnum): + operator_code += " " + helper.var_idx("cutlass::MatrixCoord tb_offset_B", i) + "{\n" + operator_code += " " + " " + "0,\n" + operator_code += " " + " " + helper.var_idx("threadblock_tile_offset.n() * B2bMma::Shape", i) + "::kN\n" + operator_code += " " + "};\n" + + operator_code += " " + "int thread_idx = threadIdx.x;\n\n" + + operator_code += " " + "MatrixCoord threadblock_offset(\n" + operator_code += " " + " " + helper.var_idx("threadblock_tile_offset.m() * B2bMma::Shape", self.b2bnum - 1) + "::kM,\n" + operator_code += " " + " " + helper.var_idx("threadblock_tile_offset.n() * B2bMma::Shape", self.b2bnum - 1) + "::kN\n" + operator_code += " " + ");\n" + + operator_code += " " + "typename B2bMma::IteratorA0 iterator_A0(\n" + operator_code += " " + " " + "params.params_A0,\n" + operator_code += " " + " " + "params.ref_A0.data(),\n" + operator_code += " " + " " + "params.problem_size_0.mk(),\n" + operator_code += " " + " " + "thread_idx,\n" + operator_code += " " + " " + "tb_offset_A0);\n" + + operator_code += " " + "iterator_A0.add_pointer_offset(batch_idx * params.problem_size_0.m() * params.problem_size_0.k());\n\n" + + + for i in range (self.b2bnum): + operator_code += " " + helper.var_idx("typename B2bMma::IteratorB", i ) + helper.var_idx(" iterator_B", i) + "(\n" + operator_code += " " + " " + helper.var_idx("params.params_B", i) + ",\n" + operator_code += " " + " " + helper.var_idx("params.ref_B", i) + ".data(),\n" + operator_code += " " + " " + helper.var_idx("params.problem_size_", i) + ".kn(),\n" + operator_code += " " + " " + "thread_idx,\n" + operator_code += " " + " " + helper.var_idx("tb_offset_B", i) + ");\n" + operator_code += " " + helper.var_idx("iterator_B", i) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", i) + helper.var_idx(".n() * params.problem_size_", i) + ".k());\n\n" + + + for i in range (self.b2bnum - 1): + operator_code += " " + helper.var_idx("typename FusedAddBiasEpilogue", i ) + helper.var_idx("::OutputTileIterator iterator_C", i) + "(\n" + operator_code += " " + " " + helper.var_idx("params.params_C", i) + ",\n" + operator_code += " " + " " + helper.var_idx("params.ref_C", i) + ".data(),\n" + operator_code += " " + " " + helper.var_idx("params.problem_size_" , i) + ".mn(),\n" + operator_code += " " + " " + "thread_idx,\n" + operator_code += " " + " " + "threadblock_offset" + ");\n" + operator_code += " " + helper.var_idx("int ref_C", i) + helper.var_idx("_stride = params.ref_C", i) + ".stride()[0];\n" + operator_code += " " + helper.var_idx("iterator_C", i) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", i) + helper.var_idx(".n() * (ref_C", i) + helper.var_idx("_stride == 0 ? 
1 : params.problem_size_", i) + ".m()));\n\n" + + + for i in range (self.b2bnum - 1): + operator_code += " " + helper.var_idx("FusedAddBiasEpilogue", i ) + helper.var_idx(" epilogue_", i ) + ";\n" + + + operator_code += " " + "int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0);\n" + operator_code += " " + "int lane_idx = threadIdx.x % 32;\n" + + for i in range (self.b2bnum - 1): + operator_code += " " + helper.var_idx("OutputOp", i) + helper.var_idx(" output_op_", i) + helper.var_idx("(params.output_op_", i) + ");\n" + + operator_code += " " + "B2bMma b2bMma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx);\n" + + operator_code += " " + "typename B2bMma::FragmentC0 src_accum;\n" + operator_code += " " + helper.var_idx("typename B2bMma::FragmentC", self.b2bnum - 1)+ " accumulators;\n" + + operator_code += " " + "src_accum.clear();\n" + operator_code += " " + "accumulators.clear();\n" + operator_code += " " + "b2bMma(params.gemm_k_iterations_0, accumulators, iterator_A0, " + + for i in range(self.b2bnum): + operator_code += helper.var_idx("iterator_B", i) + ", " + + operator_code += "src_accum" + if self.b2bnum != 1: + operator_code += ", " + for i in range(self.b2bnum - 1): + operator_code += helper.var_idx("output_op_", i) + ", " + + for i in range(self.b2bnum - 1): + operator_code += helper.var_idx("epilogue_", i) + ", " + + for i in range(self.b2bnum - 1): + final = ", " + if i == self.b2bnum - 2: + final ="" + operator_code += helper.var_idx("iterator_C", i) + final + operator_code += ");\n" + + operator_code += " " + helper.var_idx("OutputOp", self.b2bnum - 1) + helper.var_idx(" output_op_", self.b2bnum - 1) + helper.var_idx("(params.output_op_", self.b2bnum - 1) + ");\n" + operator_code += " " + "threadblock_tile_offset = threadblock_swizzle.get_tile_offset(params.grid_tiled_shape);\n" + + + + operator_code += " " + helper.var_idx("typename Epilogue::OutputTileIterator iterator_C", self.b2bnum - 1) + "(\n" + operator_code += " " + " " + helper.var_idx("params.params_C", self.b2bnum - 1) + ",\n" + operator_code += " " + " " + helper.var_idx("params.ref_C", self.b2bnum - 1) + ".data(),\n" + operator_code += " " + " " + helper.var_idx("params.problem_size_", self.b2bnum - 1) + ".mn(),\n" + operator_code += " " + " " + "thread_idx,\n" + operator_code += " " + " " + "threadblock_offset\n" + operator_code += " " + ");\n" + operator_code += " " + helper.var_idx("int ref_C", self.b2bnum - 1) + helper.var_idx("_stride = params.ref_C", self.b2bnum - 1) + ".stride()[0];\n" + + operator_code += " " + helper.var_idx("iterator_C", self.b2bnum - 1) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", self.b2bnum - 1) + helper.var_idx(".n() * (ref_C", self.b2bnum - 1) + helper.var_idx("_stride == 0 ? 
1 : params.problem_size_", self.b2bnum - 1) + ".m()));\n\n" + + operator_code += " " + helper.var_idx("typename Epilogue::OutputTileIterator iterator_D", self.b2bnum - 1) + "(\n" + operator_code += " " + " " + helper.var_idx("params.params_D", self.b2bnum - 1) + ",\n" + operator_code += " " + " " + helper.var_idx("params.ref_D", self.b2bnum - 1) + ".data(),\n" + operator_code += " " + " " + helper.var_idx("params.problem_size_", self.b2bnum - 1) + ".mn(),\n" + operator_code += " " + " " + "thread_idx,\n" + operator_code += " " + " " + "threadblock_offset\n" + operator_code += " " + ");\n" + operator_code += " " + helper.var_idx("iterator_D", self.b2bnum - 1) + helper.var_idx(".add_pointer_offset(batch_idx * params.problem_size_", self.b2bnum - 1) + helper.var_idx(".n() * params.problem_size_", self.b2bnum - 1) + ".m());\n\n" + + + operator_code += " " + "Epilogue epilogue(\n" + operator_code += " " + " " + "shared_storage.epilogue,\n" + operator_code += " " + " " + "thread_idx,\n" + operator_code += " " + " " + "warp_idx,\n" + operator_code += " " + " " + "lane_idx\n" + operator_code += " " + ");\n" + + operator_code += " " + "epilogue(" + operator_code += helper.var_idx("output_op_", self.b2bnum - 1) + ", " + operator_code += helper.var_idx("iterator_D", self.b2bnum - 1) + ", " + operator_code += "accumulators, " + operator_code += helper.var_idx("iterator_C", self.b2bnum - 1) + ");\n" + operator_code += "}\n" + + return ctr_code + operator_code + + def gen_include_header(self): + code = ''' +#pragma once + +#include \"{cutlass_dir}cutlass/cutlass.h\" + +#include \"{cutlass_dir}cutlass/gemm/gemm.h\" +#include \"{cutlass_dir}cutlass/matrix_coord.h\" +#include \"{cutlass_dir}cutlass/semaphore.h\" +'''.format(cutlass_dir=self.cutlass_deps_root) + return code + def gen_code(self): + + template_param = [] + template_param.append(("typename", "B2bMma")) + template_param.append(("typename", "Epilogue")) + template_param.append(("typename", "ThreadblockSwizzle")) + template_param.append((bool, "SplitKSerial")) + + code_body = "" + code_body += self.gen_using() + code_body += self.gen_operator_and_constr() + + struct_code = gen_ir.gen_template_struct(self.gen_class_name, template_param, code_body) + code = self.gen_include_header() + code += gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("kernel", struct_code))) + + return self.gen_include_header() + code + + + +class gen_kernel: + def __init__(self, template_param, gen_class_name, b2b_num, output_dir, cutlass_deps_root, project_root): + self.template_param = template_param + + self.gen_class_name = "B2bGemm" + self.gen_kernel_name = gen_class_name + "Kernel" + self.template_args = [] + + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + + self.gen_default_b2b_gemm = gen_default_Gemm(template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root) + self.gen_Kerenl = gen_Kernel(template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root) + + # Include gen_threadBlock + self.gen_threadBlock = gen_tb.gen_threadblock(template_param, gen_class_name, b2b_num, output_dir, cutlass_deps_root, project_root) + + self.file_dir = output_dir + "/kernel/" + + def gen_code(self, first_use_1stage): + + default_b2b_gemm = self.gen_default_b2b_gemm.gen_code() + + print("[INFO]: Gen kernel code [default_b2b_gemm.h]output Dir: is ", self.file_dir) + + with open(self.file_dir + "default_b2b_gemm.h", "w+") as f: + f.write(default_b2b_gemm) + + kernel = self.gen_Kerenl.gen_code() + 
print("[INFO]: Gen kernel code [b2b_gemm.h]output Dir: is ", self.file_dir) + + with open(self.file_dir + "b2b_gemm.h", "w+") as f: + f.write(kernel) + + # Call code to gen threadblock + self.gen_threadBlock.gen_code(first_use_1stage) diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_sample.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_sample.py new file mode 100644 index 0000000000..6474d95c5d --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_sample.py @@ -0,0 +1,232 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +import helper +import gen_ir as ir + +class gen_test: + def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"): + self.fuse_gemm_info = fuse_gemm_info + self.gen_class_name = gen_class_name + self.user_header_file = user_header_file + self.sample_dir = output_dir + self.b2b_num = len(fuse_gemm_info) + + def gen_cpp_sample(self): + code = "/* Auto Generated code - Do not edit.*/\n" + code += "#include \n" + + code += "#include \"cutlass/gemm/device/gemm_batched.h\" \n" + code += "#include \"cutlass/cutlass.h\" \n" + + code += "#include \"../cutlass_irrelevant.h\" \n" + code += "#include \"../cutlass_verify.h\" \n" + + code += "#include \"leaky_bias.h\" \n" + + code += "#include \"utils.h\" \n" + + + + code += "int main(int args, char * argv[]) {\n" + code += " " + "int M = atoi(argv[1]);\n" + code += " " + "int K0 = " + str(self.fuse_gemm_info[0]['mnk'][0]) + ";\n" + code += " " + "if(args == 3);\n" + code += " " + " " + "K0 = atoi(argv[2]);\n" + code += " " + "int B = 1;\n" + code += " " + "if(args == 4);\n" + code += " " + " " + "B = atoi(argv[3]);\n" + + code += " " + "srand(1234UL);\n" + code += " " + "int device_id = 0;\n" + code += " " + "cudaGetDevice(&device_id);\n" + code += " " + "cudaDeviceProp prop;\n" + code += " " + "cudaGetDeviceProperties(&prop, device_id);\n" + code += " " + "int sm = prop.major *10 + prop.minor;\n" + code += "using ElementCompute = cutlass::half_t;\n" + + for i in range(self.b2b_num): + code += " " + helper.var_idx("ElementCompute alpha", i) + " = ElementCompute(1);\n" + addbias = helper.get_epilogue_add_bias_or_not( self.fuse_gemm_info[i]) + if addbias: + code += " " + helper.var_idx("ElementCompute beta", i) + " = ElementCompute(1);\n" + else: + code += " " + helper.var_idx("ElementCompute beta", i) + " = ElementCompute(0);\n" + + code += " " + "size_t flops = 0;\n" + + for i in range(self.b2b_num): + m = self.fuse_gemm_info[i]['mnk'][0] + n = self.fuse_gemm_info[i]['mnk'][1] + k = self.fuse_gemm_info[i]['mnk'][2] + + bias_shape = helper.get_epilogue_bias_shape(self.fuse_gemm_info[i]) + + this_k = "K0" + if (i > 0): + this_k = str(k) + + code += " " + "flops += size_t(2) * size_t(M) * size_t(B) * " + "size_t(" + str(n) + ") * size_t(" + this_k + ");\n" + + code += " " + helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + "(" + "M" + ", " + str(n) + ", " + this_k + ");\n" + + code += " " + helper.var_idx("memory_unit Mat_A", i) + helper.var_idx("(B * problem_size_", i) + helper.var_idx(".m() * problem_size_", i) + ".k());\n" + code += " " + helper.var_idx("memory_unit Mat_B", i) + helper.var_idx("(B * problem_size_", i) + helper.var_idx(".n() * problem_size_", i) + ".k());\n" + code += " " + helper.var_idx("memory_unit Mat_C", i) + "(B * " + str(bias_shape[0]) + " * " + str(bias_shape[1]) + ");\n" + code += " " + helper.var_idx("memory_unit Mat_D_cutlass_ref", i) + helper.var_idx("(B * problem_size_", i) + helper.var_idx(".m() * problem_size_", i) + ".n());\n" + + code += " " + helper.var_idx("Mat_A", i) + ".init();\n" + code += " " + helper.var_idx("Mat_B", i) + ".init();\n" + code += " " + helper.var_idx("Mat_C", i) + ".init();\n" + + + + code += " " + helper.var_idx("memory_unit Mat_D", self.b2b_num - 1) + helper.var_idx("(B * problem_size_", i) + helper.var_idx(".m() * problem_size_",self.b2b_num - 1) + ".n());\n" + + params = [] + params.append("M") + params.append("B") + + 
params.append("Mat_A0.device_ptr") + for i in range(self.b2b_num): + params.append(helper.var_idx("Mat_B", i) + ".device_ptr") + params.append(helper.var_idx("Mat_C", i) + ".device_ptr") + if i != self.b2b_num-1: + params.append(helper.var_idx("Mat_D_cutlass_ref", i) + ".device_ptr") + params.append(helper.var_idx("Mat_D", self.b2b_num - 1) + ".device_ptr") + + code += " " + "Param arguments = {\n" + code += " " + " " + "M,\n" + code += " " + " " + "K0,\n" + code += " " + " " + "B,\n" + + code += " " + " " + "reinterpret_cast(Mat_A0.device_ptr),\n" + cnt = 1 + for i in range(self.b2b_num): + bias_flag = helper.get_epilogue_add_bias_or_not( self.fuse_gemm_info[i]) + code += " " + " " + "reinterpret_cast(" + helper.var_idx("Mat_B", i) + ".device_ptr" + "),\n" + cnt += 1 + if bias_flag: + code += " " + " " + "reinterpret_cast(" + helper.var_idx("Mat_C", i) + ".device_ptr" + "),\n" + cnt += 1 + else: + code += " " + " " + "reinterpret_cast(NULL),\n" + + epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i]) + acc_tp = helper.get_epilogue_compute_tp(self.fuse_gemm_info[i]) + for arg in epilogue_args: + arg_value = str(arg[2]) + + code += " " + " " + helper.type_2_cutlass_type(acc_tp) + "(" + arg_value + "),\n" + + if i != self.b2b_num - 1: + code += " " + " " + "reinterpret_cast(" + helper.var_idx("Mat_D_cutlass_ref", i) + ".device_ptr" + "),\n" + else: + code += " " + " " + "reinterpret_cast(" + helper.var_idx("Mat_D", i) + ".device_ptr" + ")};\n" + + + + + code += " " + "TI(FUSED_CUTLASS);\n" + code += " " + "for(int i = 0; i < 100; i++){\n" + code += " " + " " + "one_api(arguments, sm, NULL);\n" + + code += " " + "}\n" + code += " " + "TO(FUSED_CUTLASS, \"FUSED_CUTLASS\", 100);\n" + + code += "\n" + + for i in range(self.b2b_num): + code_this = "" + + N_str = str(self.fuse_gemm_info[i]['mnk'][1]) + + code_this += " " + helper.var_idx("typename Gemm", i) + helper.var_idx("::Arguments arguments_", i) + "{\n" + code_this += " " + " " + helper.var_idx("problem_size_", i) + ",\n" + ldmA = str(self.fuse_gemm_info[i]['mnk'][2]) + if i == 0: + ldmA = "K0" + ldmB = str(self.fuse_gemm_info[i]['mnk'][2]) + if i == 0: + ldmB = "K0" + ldmC = str(self.fuse_gemm_info[i]['mnk'][1]) + + ldmBias = str(helper.get_epilogue_bias_ldm(self.fuse_gemm_info[i])) + + if self.fuse_gemm_info[i]['A_format'] is 'Col': + ldmA = "M" + if self.fuse_gemm_info[i]['B_format'] is 'Row': + ldmB = str(self.fuse_gemm_info[i]['mnk'][1]) + if self.fuse_gemm_info[i]['C_format'] is 'Col': + ldmC = "M" + + if i == 0: + code_this += " " + " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("Mat_A", i) + ".device_ptr), " + ldmA + "}, " + "M * " + ldmA + ",\n" + else: + code_this += " " + " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("Mat_D_cutlass_ref", i - 1) + ".device_ptr), " + ldmA + "}, " + "M * " + ldmA + ",\n" + + code_this += " " + " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_tp']) + "*>(" + helper.var_idx("Mat_B", i) + ".device_ptr), " + ldmB + "}, " + N_str + " * " + ldmB + ",\n" + + M_bias = str(helper.get_epilogue_bias_shape(self.fuse_gemm_info[i])[0]) + + code_this += " " + " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("Mat_C", i) + ".device_ptr), " + ldmBias + "}, " + M_bias + " * " + N_str + ",\n" + code_this += " " + " " + "{reinterpret_cast<" + 
helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("Mat_D_cutlass_ref", i) + ".device_ptr), " + ldmC + "}, " + "M * " + ldmC + ",\n" + code_this += " " + " " + "{ " + helper.var_idx("alpha", i) + ", " + helper.var_idx("beta", i) + for epilogue_arg in helper.get_epilogue_args(self.fuse_gemm_info[i]): + arg_value = str(epilogue_arg[2]) + code_this += ", " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(" + str(arg_value) + ")" + code_this += " " + " },\n" + code_this += " " + " " + "B};\n" + + code += code_this + + + + code += " " + "TI(UNFUSED_CUTLASS);\n" + code += " " + "for(int i = 0; i < 100; i++){\n" + code += " " + " " + self.gen_class_name + "_verify(\n" + for i in range(self.b2b_num): + code += " " + " " + " " + helper.var_idx("arguments_", i) + ",\n" + code += " " + " " + " " + "NULL);\n" + + code += " " + "}\n" + code += " " + "TO(UNFUSED_CUTLASS, \"UNFUSED_CUTLASS\", 100);\n" + + code += " " + helper.var_idx("Mat_D_cutlass_ref", self.b2b_num - 1) + ".d2h();\n" + code += " " + helper.var_idx("Mat_D", self.b2b_num - 1) + ".d2h();\n" + code += " " + helper.var_idx("check_result(Mat_D_cutlass_ref", self.b2b_num - 1) + helper.var_idx(".host_ptr, Mat_D", self.b2b_num - 1) \ + + helper.var_idx(".host_ptr, Mat_D", self.b2b_num - 1) + ".elements);\n" + + code += "\n\n}\n" + + with open(self.sample_dir + "sample.cu", "w+") as f: + f.write(code) diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_threadblock.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_threadblock.py new file mode 100644 index 0000000000..91f9ef3ccb --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_threadblock.py @@ -0,0 +1,1013 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +import gen_ir +import helper + + +class gen_default_b2b_mma: + def __init__(self, template_param, gen_class_name, b2b_num,cutlass_deps_root, project_root): + self.gen_class_name = "DefaultB2bMma" + self.template_param = template_param + self.b2b_num = b2b_num + + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + + def gen_include_header(self): + code = ''' +/* Auto Generated code - Do not edit.*/ + +#pragma once + +#include \"{cutlass_dir}cutlass/cutlass.h\" +#include \"{cutlass_dir}cutlass/numeric_types.h\" +#include \"{cutlass_dir}cutlass/arch/arch.h\" + +#include \"{cutlass_dir}cutlass/transform/threadblock/predicated_tile_iterator.h\" +#include \"{cutlass_dir}cutlass/transform/threadblock/predicated_tile_iterator_2dthreadtile.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm70.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm75.h\" +#include \"{cutlass_dir}cutlass/gemm/threadblock/default_mma_core_sm80.h\" + +#include \"../threadblock/b2b_mma_pipelined.h\" +#include \"../../fixed_impl/epilogue/threadblock/fused_bias_act_epilogue.h\" +#include \"../../fixed_impl/epilogue/threadblock/default_bias_act_epilogue_tensor_op.h\" +#include \"../../fixed_impl/gemm/warp/mma_tensor_op_fragment_iterator_without_output_op.h\" +'''.format(cutlass_dir=self.cutlass_deps_root) + return code + + + def gen_using_MmaCore(self, stage): + threadBlockShape = "ThreadblockShape" + warpShape = "WarpShape" + instrunctionShape = "InstructionShape" + Mma_typename = "typename cutlass::gemm::threadblock::DefaultMmaCore" + + + gen_code = "" + + for i in range(self.b2b_num): + code_using = "using MmaCore" + str(i) + gen_code += code_using + " = " + gen_ir.gen_declare_template_struct(Mma_typename, \ + helper.var_idx(threadBlockShape, i), helper.var_idx(warpShape, i), instrunctionShape, \ + "ElementA", "LayoutA", \ + helper.var_idx("ElementB", i), helper.var_idx("LayoutB", i), \ + helper.var_idx("ElementAccumulator", i), "layout::RowMajor", \ + "OperatorClass", str(stage), "Operator") + return gen_code + + def gen_using_FusedAddBiasEpilogue(self): + gen_code = "" + for i in range(self.b2b_num - 1): + code_using = helper.var_idx("using FusedAddBiasEpilogue", i) + epilogue_name = "typename cutlass::epilogue::threadblock::DefaultFusedBiasActEpilogueTensorOp" + template_args = helper.var_idx("::Epilogue" + + gen_code += code_using + " = " + epilogue_name + template_args + ";\n" + + return gen_code + + + def gen_using_Iterator(self): + code_using = "using IteratorA0" + iterator_typename = "cutlass::transform::threadblock::PredicatedTileIterator" + MmaCore = "MmaCore0" + matrix_shape = "cutlass::MatrixShape<" + MmaCore + "::Shape::kM, " + MmaCore + "::Shape::kK>" + iterator_map = "typename " + MmaCore + "::IteratorThreadMapA" + gen_code = code_using + " = " + gen_ir.gen_declare_template_struct(iterator_typename, \ + matrix_shape, "ElementA", "LayoutA", "1", iterator_map, "AlignmentA_") + + for i in range(self.b2b_num): + code_using = "using IteratorB" + str(i) + iterator_typename = "cutlass::transform::threadblock::PredicatedTileIterator" + MmaCore = "MmaCore" + str(i) + matrix_shape = "cutlass::MatrixShape<" + MmaCore + "::Shape::kK, " + MmaCore + "::Shape::kN>" + iterator_map = "typename " + MmaCore + "::IteratorThreadMapB" + + gen_code += code_using + " = " + gen_ir.gen_declare_template_struct(iterator_typename, \ + matrix_shape, 
helper.var_idx("ElementB", i), helper.var_idx("LayoutB", i), "0", iterator_map, "AlignmentB_") + + return gen_code + + def gen_fragment_iterator(self): + gen_code = "using AccumulatorLayout = cutlass::layout::ColumnMajor;\n" + + for i in range(1, self.b2b_num): + code_using = "using FragmentIteratorA" + str(i) + iterator_typename = "cutlass::gemm::warp::MmaTensorOpPureFragmentIterator" + curr_MmaCore = "MmaCore" + str(i) + prev_MmaCore = "MmaCore" + str(i - 1) + Matrix_shape_curr = "cutlass::MatrixShape<" + curr_MmaCore + "::WarpShape::kM, " + curr_MmaCore + "::InstructionShape::kK>" + Matrix_shape_prev = "cutlass::MatrixShape<" + prev_MmaCore + "::WarpShape::kM, " + prev_MmaCore + "::WarpShape::kN>" + Curr_shape_kK = curr_MmaCore + "::Shape::kK" + + gen_code += code_using + " = " + gen_ir.gen_declare_template_struct(iterator_typename, \ + Matrix_shape_curr, Matrix_shape_prev, Curr_shape_kK, \ + helper.var_idx("ElementAccumulator", i-1), "ElementA", \ + "AccumulatorLayout", "InstructionShape_", "true") + + return gen_code + + def gen_threadblockmma(self): + code_using = "using ThreadblockB2bMma" + iterator_typename = "cutlass::gemm::threadblock::B2bMmaPipelined" + + MmaPipelined_param_Mma0_shape = "typename MmaCore0::Shape" + MmaPipelined_param_Mma0_iteratorA = "IteratorA0" + MmaPipelined_param_Mma0_smemIteratorA = "typename MmaCore0::SmemIteratorA" + MmaPipelined_param_Mma0_iteratorB = "IteratorB0" + MmaPipelined_param_Mma0_smemIteratorB = "typename MmaCore0::SmemIteratorB" + + MmaPipelined_param_list = MmaPipelined_param_Mma0_shape + ", " + MmaPipelined_param_Mma0_iteratorA + ", " + MmaPipelined_param_Mma0_smemIteratorA + ", " + MmaPipelined_param_Mma0_iteratorB + ", " + MmaPipelined_param_Mma0_smemIteratorB + ", " + + for i in range(1, self.b2b_num): + MmaPipelined_param_Mma_shape = "typename MmaCore" + str(i) + "::Shape" + MmaPipelined_param_Mma_iteratorA = "FragmentIteratorA" + str(i) + MmaPipelined_param_Mma_iteratorB = "IteratorB" + str(i) + MmaPipelined_param_Mma_smemIteratorB = "typename MmaCore" + str(i) + "::SmemIteratorB" + + MmaPipelined_param_list += MmaPipelined_param_Mma_shape + ", " + MmaPipelined_param_Mma_iteratorA + ", " + MmaPipelined_param_Mma_iteratorB + ", " + MmaPipelined_param_Mma_smemIteratorB + ", " + + MmaPipelined_param_list += "ElementAccumulator0, layout::RowMajor, " + + for i in range(self.b2b_num - 1): + epilogue_name = "EpilogueOutputOp" + str(i) + MmaPipelined_param_list += epilogue_name + ", " + + for i in range(self.b2b_num - 1): + epilogue_name = "FusedAddBiasEpilogue" + str(i) + MmaPipelined_param_list += epilogue_name + ", " + + for i in range(self.b2b_num): + MmaPolicy = "typename MmaCore" + str(i) + "::MmaPolicy" + MmaPipelined_param_list += MmaPolicy + ", " + + + cnt = 0 + for i in range(self.b2b_num): + MmaStage = helper.var_idx("Stages", i) + final = ", " + if cnt == self.b2b_num - 1: + final = "" + MmaPipelined_param_list += MmaStage + final + cnt += 1 + + gen_code = code_using + " = " + gen_ir.gen_declare_template_struct(iterator_typename, MmaPipelined_param_list) + + return gen_code + + + + def gen_code(self): + gen_using = '' + # Generate default template struct + gen_code = gen_ir.gen_template_struct(self.gen_class_name, self.template_param, "", speicalized = None, set_default=False) + + # Generate specialized template struct + + mmacore_codebody = self.gen_using_MmaCore(2) + iterator_codebody = self.gen_using_Iterator() + fragment_iterator_codebody = self.gen_fragment_iterator() + epilogue_iterator_codebody = 
self.gen_using_FusedAddBiasEpilogue() + threadBlockMma = self.gen_threadblockmma() + specialized_code = mmacore_codebody + iterator_codebody + fragment_iterator_codebody + epilogue_iterator_codebody + threadBlockMma + + # Specialize layout C -> cutlass::layout::RowMajor + + rtn_template_args, speicalized_template_args = gen_ir.filtered_param(self.template_param, [ ('LayoutD', "cutlass::layout::RowMajor")], keep_= True) + + gen_speical_code = gen_ir.gen_template_struct(self.gen_class_name, rtn_template_args, specialized_code, speicalized = speicalized_template_args, set_default=False) + code = gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("threadblock", gen_code + gen_speical_code))) + + return self.gen_include_header() + code + + +class gen_b2b_mme_pipelined: + def __init__(self, template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root): + self.gen_class_name = "B2bMmaPipelined" + self.template_param = template_param + self.b2b_num = b2b_num + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + + + def gen_include_header(self): + code = ''' +#pragma once + +#include \"{cutlass_dir}cutlass/cutlass.h\" +#include \"{cutlass_dir}cutlass/array.h\" +#include \"{cutlass_dir}cutlass/aligned_buffer.h\" +#include \"{cutlass_dir}cutlass/numeric_conversion.h\" + +#include \"{cutlass_dir}cutlass/numeric_types.h\" +#include \"{cutlass_dir}cutlass/matrix_shape.h\" + +#include \"{cutlass_dir}cutlass/gemm/gemm.h\" +#include \"{cutlass_dir}cutlass/gemm/warp/mma_tensor_op_fragment_iterator.h\" + +#include \"../threadblock/b2b_mma_base.h\"\n'''.format(cutlass_dir = self.cutlass_deps_root) + return code + + + def gen_using(self): + code_using = "using FragmentA0 = typename IteratorA0::Fragment;\n" + + code_using += "using Base = B2bMmaBase<" + for i in range(self.b2b_num): + code_using += helper.var_idx("Shape", i) + "_, " + for i in range(self.b2b_num): + code_using += helper.var_idx("Policy", i) + "_, " + for i in range(self.b2b_num): + code_using += helper.var_idx("Stage", i) + "_, " + code_using = code_using[: -2] + ">;\n" + + + for i in range(self.b2b_num): + code_using += helper.var_idx("using FragmentB", i) + helper.var_idx(" = typename IteratorB", i) + "::Fragment;\n" + code_using += helper.var_idx("using FragmentC", i) + helper.var_idx(" = typename Policy", i) + "::Operator::FragmentC;\n" + code_using += helper.var_idx("using Operator", i) + helper.var_idx(" = typename Policy", i) + "::Operator;\n" + + for i in range(self.b2b_num - 1): + code_using += helper.var_idx("using IteratorC", i) + helper.var_idx(" = typename FusedAddBiasEpilogue", i) + "::OutputTileIterator;\n" + + code_using += "using ArchTag = typename Policy0::Operator::ArchTag;\n" + code_using += "static ComplexTransform const kTransformA0 = Operator0::kTransformA;\n" + + for i in range(self.b2b_num): + code_using += helper.var_idx("static ComplexTransform const kTransformB", i) + helper.var_idx(" = Operator", i) + "::kTransformB;\n" + + code_using += "private:\n" + code_using += "using WarpFragmentA0 = typename Operator0::FragmentA;\n" + code_using += "using WarpFragmentB0 = typename Operator0::FragmentB;\n" + + for i in range(1, self.b2b_num): + code_using += helper.var_idx("using WarpFragmentA", i) + helper.var_idx(" = typename FragmentIteratorA", i) + "::Fragment;\n" + code_using += helper.var_idx("using WarpFragmentB", i) + helper.var_idx(" = typename Operator", i) + "::FragmentB;\n" + + code_using += "protected:\n" + + code_using += "SmemIteratorA0 
smem_iterator_A_;\n" + + for i in range(self.b2b_num): + code_using += helper.var_idx("SmemIteratorB", i) + helper.var_idx(" smem_iterator_B", i) + "_;\n" + + return code_using + + + def gen_operator(self, first_use_1stage = False): + code = "" + def gen_operator_param(b2b_num): + param_code = "" + param_code += "int gemm_k_iterations_0,\n" + param_code += helper.var_idx("FragmentC", b2b_num-1) + helper.var_idx(" &accum", b2b_num-1) + ",\n" + param_code += "IteratorA0 iterator_A,\n" + + for i in range(b2b_num): + param_code += helper.var_idx("IteratorB", i) + " " + helper.var_idx("iterator_B", i) + ",\n" + + param_code += "FragmentC0 const &src_accum, \n" + + for i in range(b2b_num - 1): + param_code += helper.var_idx("OutputOp", i) + " " + helper.var_idx("output_op_", i) + ",\n" + for i in range(b2b_num - 1): + param_code += helper.var_idx("FusedAddBiasEpilogue", i) + " " + helper.var_idx("epilogue_", i) + ",\n" + for i in range(b2b_num - 1): + param_code += helper.var_idx("IteratorC", i) + " " + helper.var_idx("iterator_C", i) + ",\n" + + + param_code += "TransformA0 transform_A0 = TransformA0(), \n" + + for i in range(b2b_num): + final = "(),\n" + if i == b2b_num - 1: + final = "()\n" + param_code += helper.var_idx("TransformB", i) + " " + helper.var_idx("transform_B", i) + " = " +helper.var_idx("TransformB", i) + final + + return param_code + + + + def gen_first_gemm_1stage(b2b_num): + accu_code = " FragmentC0 accum0 = src_accum;\n" + if b2b_num == 1: + accu_code = " accum0 = src_accum;\n" + + code ="\ +\n\ + FragmentA0 tb_frag_A;\n\ + FragmentB0 tb_frag_B0;\n\ +\n\ + int smem_write_stage_idx = 1;\n\ +\n\ + tb_frag_A.clear();\n\ + tb_frag_B0.clear();\n\ +\n\ + // The last kblock is loaded in the prolog\n\ + iterator_A.load(tb_frag_A);\n\ + iterator_B0.load(tb_frag_B0);\n\ +\n\ + ++iterator_A;\n\ + ++iterator_B0;\n\ +\n\ + WarpFragmentA0 warp_frag_A0;\n\ + WarpFragmentB0 warp_frag_B0;\n\ +\n\ + Operator0 warp_mma0;\n\ +\n\ + // Avoid reading out of bounds\n\ + if (gemm_k_iterations_0 <= 1) {\n\ + iterator_A.clear_mask();\n\ + iterator_B0.clear_mask();\n\ + }\n\ +\n\ + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing \n\ + // shared memory loads (which have the tightest latency requirement).\n\ +\n\ + //\n\ + // Mainloop\n\ + //\n\ +\n\ + // Note: The main loop does not support Base::WarpGemmIterations == 2.\n\ + CUTLASS_GEMM_LOOP\n\ + for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) {\n\ +\n\ + this->smem_iterator_A_.store(tb_frag_A);\n\ + this->smem_iterator_B0_.store(tb_frag_B0);\n\ +\n\ + __syncthreads();\n\ + //\n\ + // Loop over GEMM K dimension\n\ + //\n\ +\n\ + CUTLASS_PRAGMA_UNROLL\n\ + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) {\n\ +\n\ + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group\n\ + // as the case may be.\n\ +\n\ + this->warp_tile_iterator_A0_.set_kgroup_index(warp_mma_k % Base::kWarpGemmIterations0);\n\ + this->warp_tile_iterator_B0_.set_kgroup_index(warp_mma_k % Base::kWarpGemmIterations0);\n\ +\n\ + this->warp_tile_iterator_A0_.load(warp_frag_A0);\n\ + this->warp_tile_iterator_B0_.load(warp_frag_B0);\n\ +\n\ + ++this->warp_tile_iterator_A0_;\n\ + ++this->warp_tile_iterator_B0_;\n\ +\n\ + warp_mma0(accum0, warp_frag_A0, warp_frag_B0, accum0);\n\ + }\n\ + this->warp_tile_iterator_A0_.add_tile_offset({0, -Policy0::kPartitionsK * Base::kWarpGemmIterations0});\n\ + this->warp_tile_iterator_B0_.add_tile_offset({-Policy0::kPartitionsK * 
Base::kWarpGemmIterations0, 0});\n\ +\n\ + __syncthreads();\n\ + iterator_A.load(tb_frag_A);\n\ + iterator_B0.load(tb_frag_B0);\n\ +\n\ + ++iterator_A;\n\ + ++iterator_B0;\n\ +\n\ + if(gemm_k_iterations_0 <= 2) {\n\ + iterator_A.clear_mask();\n\ + iterator_B0.clear_mask();\n\ + }\n\ + }\n" + + return accu_code + code + + + def gen_first_gemm_2stage(b2b_num): + + accu_code = " FragmentC0 accum0 = src_accum;\n" + if b2b_num == 1: + accu_code = " accum0 = src_accum;\n" + + code ="\ +\n\ + FragmentA0 tb_frag_A;\n\ + FragmentB0 tb_frag_B0;\n\ +\n\ + tb_frag_A.clear();\n\ + tb_frag_B0.clear();\n\ +\n\ + // The last kblock is loaded in the prolog\n\ + iterator_A.load(tb_frag_A);\n\ + iterator_B0.load(tb_frag_B0);\n\ +\n\ + ++iterator_A;\n\ + ++iterator_B0;\n\ +\n\ + this->smem_iterator_A_.store(tb_frag_A);\n\ + this->smem_iterator_B0_.store(tb_frag_B0);\n\ +\n\ + ++this->smem_iterator_A_;\n\ + ++this->smem_iterator_B0_;\n\ +\n\ + __syncthreads();\n\ +\n\ + // Pair of fragments used to overlap shared memory loads and math instructions\n\ + WarpFragmentA0 warp_frag_A0[2];\n\ + WarpFragmentB0 warp_frag_B0[2];\n\ +\n\ + this->warp_tile_iterator_A0_.set_kgroup_index(0);\n\ + this->warp_tile_iterator_B0_.set_kgroup_index(0);\n\ +\n\ + this->warp_tile_iterator_A0_.load(warp_frag_A0[0]);\n\ + this->warp_tile_iterator_B0_.load(warp_frag_B0[0]);\n\ +\n\ + ++this->warp_tile_iterator_A0_;\n\ + ++this->warp_tile_iterator_B0_;\n\ +\n\ + Operator0 warp_mma0;\n\ +\n\ + int smem_write_stage_idx = 1;\n\ +\n\ + // Avoid reading out of bounds\n\ + if (gemm_k_iterations_0 <= 1) {\n\ + iterator_A.clear_mask();\n\ + iterator_B0.clear_mask();\n\ + }\n\ +\n\ + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing \n\ + // shared memory loads (which have the tightest latency requirement).\n\ + iterator_A.load(tb_frag_A);\n\ +\n\ + //\n\ + // Mainloop\n\ + //\n\ +\n\ + // Note: The main loop does not support Base::WarpGemmIterations == 2.\n\ + CUTLASS_GEMM_LOOP\n\ + for (; gemm_k_iterations_0 > 0; --gemm_k_iterations_0) {\n\ +\n\ + //\n\ + // Loop over GEMM K dimension\n\ + //\n\ +\n\ + CUTLASS_PRAGMA_UNROLL\n\ + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations0; ++warp_mma_k) {\n\ +\n\ + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group\n\ + // as the case may be.\n\ +\n\ + if (warp_mma_k == Base::kWarpGemmIterations0 - 1) {\n\ +\n\ + // Write fragments to shared memory\n\ + this->smem_iterator_A_.store(tb_frag_A);\n\ +\n\ + this->smem_iterator_B0_.store(tb_frag_B0);\n\ +\n\ + __syncthreads();\n\ +\n\ + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing \n\ + // shared memory loads (which have the tightest latency requirement).\n\ + iterator_A.load(tb_frag_A);\n\ + \n\ + ++this->smem_iterator_B0_;\n\ + ++this->smem_iterator_A_;\n\ + \n\ +\n\ + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory\n\ + if (smem_write_stage_idx == 1) {\n\ + this->smem_iterator_A_.add_tile_offset({0, -Base::Stage0});\n\ + this->smem_iterator_B0_.add_tile_offset({-Base::Stage0, 0});\n\ + }\n\ + else {\n\ + this->warp_tile_iterator_A0_.add_tile_offset(\n\ + {0, -Base::Stage0 * Policy0::kPartitionsK * Base::kWarpGemmIterations0});\n\ + this->warp_tile_iterator_B0_.add_tile_offset(\n\ + {-Base::Stage0 * Policy0::kPartitionsK * Base::kWarpGemmIterations0,\n\ + 0});\n\ + }\n\ +\n\ + smem_write_stage_idx ^= 1;\n\ + }\n\ +\n\ + this->warp_tile_iterator_A0_.set_kgroup_index((warp_mma_k + 1) % 
Base::kWarpGemmIterations0);\n\ + this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations0);\n\ + \n\ + this->warp_tile_iterator_A0_.load(warp_frag_A0[(warp_mma_k + 1) % 2]);\n\ + this->warp_tile_iterator_B0_.load(warp_frag_B0[(warp_mma_k + 1) % 2]);\n\ +\n\ + ++this->warp_tile_iterator_A0_;\n\ + ++this->warp_tile_iterator_B0_;\n\ +\n\ + if (warp_mma_k == 0) {\n\ +\n\ + iterator_B0.load(tb_frag_B0);\n\ +\n\ + ++iterator_A;\n\ + ++iterator_B0;\n\ +\n\ + // Avoid reading out of bounds if this was the last loop iteration\n\ + if (gemm_k_iterations_0 <= 2) {\n\ + iterator_A.clear_mask();\n\ + iterator_B0.clear_mask();\n\ + }\n\ + }\n\ +\n\ + warp_mma0(accum0, warp_frag_A0[warp_mma_k % 2], warp_frag_B0[warp_mma_k % 2], accum0);\n\ + }\n\ + }\n" + return accu_code + code + + def gen_other_gemms_2stage(b2b_num): + + code = "" + + def gemm_teamplate(id): + code = "// " + str(id + 1) + " Gemm" + code += " /// Iterator to load a warp-scoped tile of A1 operand from intermediate accumulator tile\n" + + code += " " + helper.var_idx("FragmentC", id - 1) + helper.var_idx(" after_epilogue_accu", id - 1) + ";\n" + code += " " + helper.var_idx("epilogue_", id - 1) + helper.var_idx("(output_op_", id - 1) + helper.var_idx(", accum", id - 1) \ + + helper.var_idx(", after_epilogue_accu", id - 1) + helper.var_idx(", iterator_C", id - 1) +");\n" + + # FragmentIteratorA1 warp_tile_iterator_A1_(accum0); + code += " " + helper.var_idx("FragmentIteratorA", id) + helper.var_idx(" warp_tile_iterator_A", id) +"_(" + helper.var_idx("after_epilogue_accu", id - 1) + ");\n" + # FragmentB1 tb_frag_B1; + code += " " + helper.var_idx("FragmentB", id) + " " + helper.var_idx("tb_frag_B", id) + ";\n" + # tb_frag_B1.clear(); + code += " " + helper.var_idx("tb_frag_B", id) + ".clear();\n" + # iterator_B1.load(tb_frag_B1); + code += " " + helper.var_idx("iterator_B", id) + ".load(" + helper.var_idx("tb_frag_B", id) + ");\n" + # ++iterator_B1; + code += " " + "++" + helper.var_idx("iterator_B", id) + ";\n" + # this->smem_iterator_B1_.store(tb_frag_B1); + code += " " + helper.var_idx("this->smem_iterator_B", id) + "_.store(" + helper.var_idx("tb_frag_B", id) + ");\n" + # ++this->smem_iterator_B1_; + code += " " + helper.var_idx("++this->smem_iterator_B", id) + "_;\n" + # __syncthreads(); + code += " " + "__syncthreads();\n" + # WarpFragmentA1 warp_frag_A1[2]; + code += " " + helper.var_idx("WarpFragmentA", id) + helper.var_idx(" warp_frag_A", id) + "[2];\n" + # WarpFragmentB1 warp_frag_B1[2]; + code += " " + helper.var_idx("WarpFragmentB", id) + helper.var_idx(" warp_frag_B", id) + "[2];\n" + # this->warp_tile_iterator_B1_.set_kgroup_index(0); + code += " " + helper.var_idx("this->warp_tile_iterator_B", id) + "_.set_kgroup_index(0);\n" + # warp_tile_iterator_A1_.load(warp_frag_A1[0], output_op_0); + code += " " + helper.var_idx("warp_tile_iterator_A", id) + helper.var_idx("_.load(warp_frag_A", id) + "[0]);\n" + # this->warp_tile_iterator_B1_.load(warp_frag_B1[0]); + code += " " + helper.var_idx("this->warp_tile_iterator_B", id) + helper.var_idx("_.load(warp_frag_B", id) + "[0]);\n" + # ++warp_tile_iterator_A1_; + code += " " + helper.var_idx("++warp_tile_iterator_A", id) + "_;\n" + # ++this->warp_tile_iterator_B1_; + code += " " + helper.var_idx("++this->warp_tile_iterator_B", id) + "_;\n" + # Operator1 warp_mma1; + code += " " + helper.var_idx("Operator", id) + " " + helper.var_idx("warp_mma", id) + ";\n" + # smem_write_stage_idx = 1; + code += " " + "smem_write_stage_idx = 1;\n" + # int 
gemm_k_iterations_1 = FragmentIteratorA1::Policy::kIterations / Base::kWarpGemmIterations1; + code += " " + helper.var_idx("int gemm_k_iterations_", id) + " = " + helper.var_idx("FragmentIteratorA", id) + helper.var_idx("::Policy::kIterations / Base::kWarpGemmIterations", id) +";\n" + # if (gemm_k_iterations_1 <= 1) { + # iterator_B1.clear_mask(); + # } + code += " " + "if (" + helper.var_idx("gemm_k_iterations_", id) + " <= 1 ){\n" \ + + " " + " " + helper.var_idx("iterator_B", id) + ".clear_mask();\n" \ + + " " +"}\n" + # CUTLASS_PRAGMA_UNROLL + code += " " + "CUTLASS_PRAGMA_UNROLL\n" + # for (; gemm_k_iterations_1 > 0; --gemm_k_iterations_1) { + code += " " + helper.var_idx("for (; gemm_k_iterations_", id) + helper.var_idx(" > 0; --gemm_k_iterations_", id) + ") {\n" + # CUTLASS_PRAGMA_UNROLL + code += " " + " " + "CUTLASS_PRAGMA_UNROLL\n" + # for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations1; ++warp_mma_k) { + code += " " + " " + helper.var_idx("for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations", id) + "; ++warp_mma_k) {\n" + # if (warp_mma_k == Base::kWarpGemmIterations1 - 1) { + code += " " + " " + " " + helper.var_idx("if (warp_mma_k == Base::kWarpGemmIterations", id) + " - 1) {\n" + # this->smem_iterator_B1_.store(tb_frag_B1); + code += " " + " " + " " + " " + helper.var_idx(" this->smem_iterator_B", id) + helper.var_idx("_.store(tb_frag_B", id) + ");\n" + # __syncthreads(); + code += " " + " " + " " + " " + "__syncthreads();\n" + # ++smem_iterator_B1_; + code += " " + " " + " " + " " + helper.var_idx(" ++smem_iterator_B", id) + "_;\n" + # if (smem_write_stage_idx == 1) { + # smem_iterator_B1_.add_tile_offset({-Base::Stage, 0}); + # } + code += " " + " " + " " + " " + "if ( smem_write_stage_idx == 1 ) {\n" \ + + " " + " " + " " + " " + " " + helper.var_idx("smem_iterator_B", id) + helper.var_idx("_.add_tile_offset({-Base::Stage", i) + ", 0});\n" \ + + " " + " " + " " + " " +"}\n" + # else { + # this->warp_tile_iterator_B1_.add_tile_offset( + # {-Base::Stage * Policy1::kPartitionsK * + # Base::kWarpGemmIterations1, + # 0}); + # } + code += " " + " " + " " + " " + "else {\n" \ + + " " + " " + " " + " " + " " + helper.var_idx("this->warp_tile_iterator_B", id) + "_.add_tile_offset(\n" \ + + " " + " " + " " + " " + " " + helper.var_idx("{-Base::Stage", id) + helper.var_idx(" * Policy", id) + "::kPartitionsK *\n" \ + + " " + " " + " " + " " + " " + helper.var_idx("Base::kWarpGemmIterations", id) + ",\n" \ + + " " + " " + " " + " " + " " + "0});\n" \ + + " " + " " + " " + " " + "}\n" + + # smem_write_stage_idx ^= 1; + # } + code += " " + " " + " " + " " + "smem_write_stage_idx ^= 1;\n" \ + + " " + " " + " " + "}\n" + + # this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations1); + code += " " + " " + " " + helper.var_idx("this->warp_tile_iterator_B", id) + helper.var_idx("_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations", id) + ");\n" + # warp_tile_iterator_A1_.load(warp_frag_A1[(warp_mma_k + 1) % 2], output_op_0); + code += " " + " " + " " + helper.var_idx("warp_tile_iterator_A", id) + helper.var_idx("_.load(warp_frag_A", id) + "[(warp_mma_k + 1) % 2]);\n" + # this->warp_tile_iterator_B1_.load(warp_frag_B1[(warp_mma_k + 1) % 2]); + code += " " + " " + " " + helper.var_idx("this->warp_tile_iterator_B", id) + helper.var_idx("_.load(warp_frag_B", id) + "[(warp_mma_k + 1) % 2]);\n" + # ++warp_tile_iterator_A1_; + code += " " + " " + " " + helper.var_idx("++warp_tile_iterator_A", id) + "_;\n" + # 
++this->warp_tile_iterator_B1_; + code += " " + " " + " " + helper.var_idx("++this->warp_tile_iterator_B", id) + "_;\n" + # if (warp_mma_k == 0) { + # iterator_B1.load(tb_frag_B1); + # ++iterator_B1; + # if (gemm_k_iterations_1 <= 2) { + # iterator_B1.clear_mask(); + # } + # } + code += " " + " " + " " + " if (warp_mma_k == 0) {\n" \ + + " " + " " + " " + " " + helper.var_idx("iterator_B", id) + helper.var_idx(".load(tb_frag_B", id) + ");\n" \ + + " " + " " + " " + " " + helper.var_idx("++iterator_B", id) +";\n" \ + + " " + " " + " " + " " + helper.var_idx("if (gemm_k_iterations_", id) +" <= 2) {\n" \ + + " " + " " + " " + " " + " " + helper.var_idx("iterator_B", id) + ".clear_mask();\n" \ + + " " + " " + " " + " " + "}\n" \ + + " " + " " + " " + "}\n" + # warp_mma1(accum, warp_frag_A1[warp_mma_k % 2], warp_frag_B1[warp_mma_k % 2], accum); + # } + # } + code += " " + " " + " " + helper.var_idx("warp_mma", id) + helper.var_idx("(accum", id) + helper.var_idx(", warp_frag_A", id) + helper.var_idx("[warp_mma_k % 2], warp_frag_B", id) + helper.var_idx("[warp_mma_k % 2], accum", id) + ");\n" \ + + " " + " " + "}\n" \ + + " " + "}\n\n\n" + + return code + + for i in range (1, b2b_num): + clear_accu = "" + if i != b2b_num - 1: + clear_accu = " " + helper.var_idx("FragmentC", i) + helper.var_idx(" accum", i) +";\n" + clear_accu += " " + helper.var_idx("accum", i) +".clear();\n" + code += clear_accu + gemm_teamplate(i) + + return code + + operator_code = " CUTLASS_DEVICE\n\ + void operator()(\n " + gen_operator_param(self.b2b_num) + ") {\n" + if first_use_1stage: + operator_code += gen_first_gemm_1stage(self.b2b_num) + else: + operator_code += gen_first_gemm_2stage(self.b2b_num) + operator_code += gen_other_gemms_2stage(self.b2b_num) + "}\n" + return operator_code + + def gen_construct_func(self): + name = self.gen_class_name + func_code = "CUTLASS_DEVICE\n" + func_code += name + "(\n" \ + + " " + "typename Base::B2bMmaSharedStorage &shared_storage,\n" \ + + " " + "int thread_idx,\n" \ + + " " + "int warp_idx,\n" \ + + " " + "int lane_idx\n" \ + + "):\n" + func_code += " " + "Base(shared_storage, thread_idx, warp_idx, lane_idx),\n" \ + + " " + "smem_iterator_A_(shared_storage.sharedStorage0.operand_A_ref(), thread_idx),\n" + + for i in range(self.b2b_num): + final = ",\n" + if i == self.b2b_num - 1: + final = " {\n" + func_code += helper.var_idx("smem_iterator_B", i) + helper.var_idx("_(shared_storage.sharedStorage", i) +".operand_B_ref(), thread_idx)" + final + + func_code += " " + "int warp_idx_mn = warp_idx % (Base::WarpCount0::kM * Base::WarpCount0::kN);\n" + func_code += " " + "int warp_idx_k = warp_idx / (Base::WarpCount0::kM * Base::WarpCount0::kN);\n" + + func_code += " " + "int warp_idx_m = warp_idx_mn % Base::WarpCount0::kM;\n" + func_code += " " + "int warp_idx_n = warp_idx_mn / Base::WarpCount0::kM;\n" + + for i in range(self.b2b_num): + func_code += " " + helper.var_idx("int tile_offset_k", i) + helper.var_idx(" = Base::kWarpGemmIterations", i) + " * warp_idx_k;\n" + + func_code += " " + "this->warp_tile_iterator_A0_.add_tile_offset({warp_idx_m, tile_offset_k0});\n" + + for i in range(self.b2b_num): + func_code += " " + helper.var_idx("this->warp_tile_iterator_B", i) + helper.var_idx("_.add_tile_offset({tile_offset_k", i) + ", warp_idx_n});\n" + + func_code += "}\n" + + return func_code + + def gen_member_func(self, first_use_1stage): + code = "public:\n" + code += self.gen_operator(first_use_1stage) + code += self.gen_construct_func() + + return code + + def gen_code(self, 
first_use_1stage): + + def gen_template_args(b2b_num): + template_param = [] + template_param.append(("typename", "Shape0")) + template_param.append(("typename", "IteratorA0")) + template_param.append(("typename", "SmemIteratorA0")) + template_param.append(("typename", "IteratorB0")) + template_param.append(("typename", "SmemIteratorB0")) + + for i in range(1, b2b_num): + template_param.append(("typename", helper.var_idx("Shape", i))) + template_param.append(("typename", helper.var_idx("FragmentIteratorA", i))) + template_param.append(("typename", helper.var_idx("IteratorB", i))) + template_param.append(("typename", helper.var_idx("SmemIteratorB", i))) + + template_param.append(("typename", "ElementC")) + template_param.append(("typename", "LayoutC")) + + for i in range(0, b2b_num - 1): + template_param.append(("typename", helper.var_idx("OutputOp", i))) + + for i in range(0, b2b_num - 1): + template_param.append(("typename", helper.var_idx("FusedAddBiasEpilogue", i))) + + for i in range(0, b2b_num): + template_param.append(("typename", helper.var_idx("Policy", i))) + for i in range(0, b2b_num): + template_param.append((int, helper.var_idx("Stage", i))) + + template_param.append(("typename","TransformA0", "NumericArrayConverter")) + + for i in range(0, b2b_num): + cvtr = helper.var_idx("NumericArrayConverter" + template_param.append(("typename", helper.var_idx("TransformB", i), cvtr)) + + template_param.append(("typename", "Enable", "bool")) + + return template_param + + template_param = gen_template_args(self.b2b_num) + inheritance_code = "public B2bMmaBase<" + for i in range(self.b2b_num): + inheritance_code += helper.var_idx("Shape", i) + "_, " + for i in range(self.b2b_num): + inheritance_code += helper.var_idx("Policy", i) + "_, " + for i in range(self.b2b_num - 1): + inheritance_code += helper.var_idx("Stage", i) + "_, " + inheritance_code += helper.var_idx("Stage", self.b2b_num - 1) + "_" + inheritance_code += ">" + + code_body = "" + using_code= self.gen_using() + func_code = self.gen_member_func(first_use_1stage) + + code_body = using_code + func_code + + class_code = gen_ir.gen_template_class(self.gen_class_name, template_param, code_body, inheritance_code = inheritance_code) + + code = self.gen_include_header() + code += gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("threadblock", class_code))) + # print(code) + return code + + +class gen_b2b_mma_base: + def __init__(self, template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root): + self.gen_class_name = gen_class_name + self.template_param = template_param + self.b2b_num = b2b_num + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + + def gen_include_header(self): + code = ''' +#pragma once + +#include \"{cutlass_dirs}cutlass/aligned_buffer.h\" +#include \"{cutlass_dirs}cutlass/arch/memory.h\" +#include \"{cutlass_dirs}cutlass/array.h\" +#include \"{cutlass_dirs}cutlass/cutlass.h\" +#include \"{cutlass_dirs}cutlass/gemm/gemm.h\" +#include \"{cutlass_dirs}cutlass/matrix_shape.h\" +#include \"{cutlass_dirs}cutlass/numeric_types.h\"\n'''.format(cutlass_dirs=self.cutlass_deps_root) + return code + + def gen_shared_storage(self): + code = \ +" template< \n\ + typename Shape_,\n\ + typename Policy_,\n\ + int ThisStage_\n\ +>\n\ +class SharedStorage {\n\ +public:\n\ + using Shape = Shape_;\n\ + using Policy = Policy_;\n\ + static int const ThisStage = ThisStage_;\n\ + using Operator = typename Policy::Operator;\n\ + \ + using TensorRefA = TensorRef;\n\ + 
\ + /// Tensor reference to the B operand \n\ + using TensorRefB = TensorRef;\n\ +\n\ + /// Shape of the A matrix operand in shared memory \n\ + using ShapeA = MatrixShape;\n\ +\n\ + /// Shape of the B matrix operand in shared memory\n\ + using ShapeB =\n\ + MatrixShape;\n\ +\n\ + public:\n\ +\n\ + /// Buffer for A operand\n\ + AlignedBuffer operand_A;\n\ +\n\ + /// Buffer for B operand\n\ + AlignedBuffer operand_B;\n\ +\n\ + public:\n\ +\n\ + /// Returns a layout object for the A matrix\n\ + CUTLASS_DEVICE\n\ + static typename Operator::LayoutA LayoutA() {\n\ + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn});\n\ + }\n\ +\n\ + /// Returns a layout object for the B matrix\n\ + CUTLASS_HOST_DEVICE\n\ + static typename Operator::LayoutB LayoutB() {\n\ + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn});\n\ + }\n\ +\n\ + /// Returns a TensorRef to the A operand\n\ + CUTLASS_HOST_DEVICE\n\ + TensorRefA operand_A_ref() {\n\ + return TensorRefA{operand_A.data(), LayoutA()};\n\ + }\n\ +\n\ + /// Returns a TensorRef to the B operand\n\ + CUTLASS_HOST_DEVICE\n\ + TensorRefB operand_B_ref() {\n\ + return TensorRefB{operand_B.data(), LayoutB()};\n\ + }\n\ + CUTLASS_HOST_DEVICE\n\ + void * get_B_Shared_ptr() {\n\ + return operand_B.data();\n\ + }\n\ + };\n" + return code + + def gen_using_and_misc(self, b2b_num): + code_using = "" + for i in range(b2b_num): + code_using += "using Operator" +str(i) + " = typename Policy" + str(i) +"::Operator;\n" + + for i in range(b2b_num): + code_using += "using WarpGemm" +str(i) + " = typename Policy" + str(i) +"::Operator::Shape;\n" + + for i in range(b2b_num): + code_using += "using WarpCount" +str(i) + " = GemmShape<" + helper.var_idx("Shape", i) +"::kM / " + helper.var_idx("WarpGemm", i) +"::kM, "\ + + helper.var_idx("Shape", i) +"::kN / " + helper.var_idx("WarpGemm", i) +"::kN, "\ + + helper.var_idx("Shape", i) +"::kK / " + helper.var_idx("WarpGemm", i) +"::kK>;\n" + + code_misc = "" + for i in range(b2b_num): + code_misc += "static int const " + helper.var_idx("kWarpGemmIterations", i) + " = (" + helper.var_idx("WarpGemm", i) + "::kK / " + helper.var_idx("Operator", i) +"::Policy::MmaShape::kK);\n" + + code = code_using + code_misc + self.gen_shared_storage() + + for i in range(b2b_num): + code += "using " + helper.var_idx("SharedStorage", i) + " = SharedStorage<" + helper.var_idx("Shape", i) + ", " + helper.var_idx("Policy", i) +", " + helper.var_idx("Stage", i) + ">;\n" + + def gen_union_shared_storage(b2b_num): + code = "" + for i in range(b2b_num): + code += " " +helper.var_idx("SharedStorage", i) + " " + helper.var_idx("sharedStorage", i) +";\n" + return code + + code += "union B2bMmaSharedStorage {\n" + gen_union_shared_storage(self.b2b_num) + "};\n" + + for i in range(b2b_num - 1): + code += helper.var_idx("void * C", i) + "_smm_ptr;\n" + + return code + + def gen_protected(self): + code = "\nprotected:\n" + code += "typename Operator0::IteratorA warp_tile_iterator_A0_;\n" + for i in range(self.b2b_num): + code += "typename Operator" +str(i) + "::IteratorB" +" warp_tile_iterator_B" + str(i) + "_;\n" + return code + + def gen_public_member(self): + code = "\npublic:\n" + + code += "CUTLASS_DEVICE\n" + code += \ + "B2bMmaBase(\n" + \ + " B2bMmaSharedStorage & shared_storage,\n" + \ + " int thread_idx,\n" + \ + " int warp_idx,\n" + \ + " int lane_idx\n" + \ + "):\n" + \ + " warp_tile_iterator_A0_(shared_storage.sharedStorage0.operand_A_ref(), lane_idx),\n" + for i in range(self.b2b_num): + final = ",\n" + if i == 
self.b2b_num-1: + final = "\n" + + iterator = " warp_tile_iterator_B" + str(i) + "_" + shared_storage = "shared_storage.sharedStorage" + str(i) + ".operand_B_ref()" + code += iterator + "(" + shared_storage + ", lane_idx)" + final + + + code += "{\n" + for i in range(self.b2b_num - 1): + code += helper.var_idx(" C", i) + helper.var_idx("_smm_ptr = shared_storage.sharedStorage", i) + ".get_B_Shared_ptr();\n" + code += "}\n" + + return code + + def gen_code(self): + + template_arg = [] + for i in range(self.b2b_num): + template_arg.append(("typename", helper.var_idx("Shape", i))) + for i in range(self.b2b_num): + template_arg.append(("typename", helper.var_idx("Policy", i))) + for i in range(self.b2b_num): + template_arg.append((int, helper.var_idx("Stage", i))) + + + + code_body = self.gen_using_and_misc(self.b2b_num) + code_body += self.gen_protected() + code_body += self.gen_public_member() + + class_code = gen_ir.gen_template_class("B2bMmaBase", template_arg, code_body) + + code = self.gen_include_header() + gen_ir.gen_namespace("cutlass", gen_ir.gen_namespace("gemm", gen_ir.gen_namespace("threadblock", class_code))) + + return code + + +class gen_threadblock: + def __init__(self, template_param, gen_class_name, b2b_num, output_dir, cutlass_deps_root, project_root): + self.gen_class_name = gen_class_name + self.template_param = template_param + self.b2b_num = b2b_num + self.file_dir = output_dir + "/threadblock/" + + self.cutlass_deps_root = cutlass_deps_root + self.project_root = project_root + + + self.gen_b2b_mma_base = gen_b2b_mma_base(template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root) + self.gen_b2b_mma_pipelined = gen_b2b_mme_pipelined(template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root) + self.gen_default_b2b_mma = gen_default_b2b_mma(template_param, gen_class_name, b2b_num, cutlass_deps_root, project_root) + + + def gen_code(self, first_use_1stage): + + base_code = self.gen_b2b_mma_base.gen_code() + print("[INFO]: Gen kernel code [b2b_mma_base.h]output Dir: is ", self.file_dir) + + with open(self.file_dir + "b2b_mma_base.h", "w+") as f: + f.write(base_code) + pipeline_code = self.gen_b2b_mma_pipelined.gen_code(first_use_1stage = first_use_1stage) + print("[INFO]: Gen kernel code [b2b_mma_pipelined.h]output Dir: is ", self.file_dir) + + with open(self.file_dir + "b2b_mma_pipelined.h", "w+") as f: + f.write(pipeline_code) + default_code = self.gen_default_b2b_mma.gen_code() + print("[INFO]: Gen kernel code [default_b2b_mma.h]output Dir: is ", self.file_dir) + + with open(self.file_dir + "default_b2b_mma.h", "w+") as f: + f.write(default_code) diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_turing_and_volta.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_turing_and_volta.py new file mode 100644 index 0000000000..db1ec4c72f --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_turing_and_volta.py @@ -0,0 +1,456 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import helper +import gen_ir as ir + +class gen_turing_impl: + def __init__(self,fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"): + self.fuse_gemm_info = fuse_gemm_info + self.class_name = gen_class_name + self.gen_class_name = gen_class_name + "_turing_impl" + self.user_header_file = "" + for header in user_header_file: + self.user_header_file += "#include \"" + header + "\"\n" + self.output_dir = output_dir + self.b2b_num = len(fuse_gemm_info) + + self.gen_turing_unfused = gen_volta_turing_fuse_act_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir) + + def gen_using(self): + code_using = "using b2b_gemm = typename cutlass::gemm::device::" + self.class_name + ";" + + return code_using + "\n" + + def gen_initialize(self): + code = "" + for i in range(self.b2b_num): + code_this = "" + + code_this += helper.var_idx(helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + " alpha", i) + " = " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(1);\n" + beta = "(1)" + + if helper.get_epilogue_add_bias_or_not(self.fuse_gemm_info[i]) is False: + beta = "(0)" + code_this += helper.var_idx(helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + " beta", i) + " = " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + beta + ";\n" + k_str = str(self.fuse_gemm_info[i]['mnk'][2]) + if i == 0: + k_str = "K0" + code_this += helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + "(M, " + str(self.fuse_gemm_info[i]['mnk'][1]) + ", " + k_str + ");\n" + code += code_this + code += "typename b2b_gemm::Arguments arguments{\n" + + for i in range(self.b2b_num): + code += " " + helper.var_idx("problem_size_", i) + ",\n" + + + code += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("A", 0) + "), " + helper.var_idx("problem_size_", 0) + ".k()},\n" + + for i in range(self.b2b_num): + + ldmB = str(self.fuse_gemm_info[i]['mnk'][2]) + if i == 0: + ldmB = "K0" + + if self.fuse_gemm_info[i]['B_format'] is 'Row': + ldmB = str(self.fuse_gemm_info[i]['mnk'][1]) + + ldmC = str(helper.get_epilogue_bias_ldm(self.fuse_gemm_info[i])) + + code += " " + "{reinterpret_cast<" + 
helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_tp']) + "*>(" + helper.var_idx("B", i) + "), " + ldmB + "},\n" + code += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("C", i) + "), " + ldmC + "},\n" + code += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("D", self.b2b_num -1) + "), " + helper.var_idx("problem_size_", self.b2b_num - 1) + ".n()},\n" + + + for i in range(self.b2b_num): + code += " " + "{ " + helper.var_idx("alpha", i) + ", " + helper.var_idx("beta", i) + for epilogue_arg in helper.get_epilogue_args(self.fuse_gemm_info[i]): + arg_name = helper.var_idx("Epilogue", i) + "_" + epilogue_arg[1] + code += ", " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(" + str(arg_name) + ")" + code += "},\n" + code += " " + "Batch};\n\n" + + code += " " "b2b_gemm gemm_op;\n" + code += " " + "gemm_op.initialize(arguments);\n" + return code + "\n" + + + + def gen_run(self): + code = " " + "gemm_op(stream);\n" + + return code + + def gen_wrapper(self): + code_body = "" + + arg_lists = [] + arg_lists.append(["int", "M"]) + arg_lists.append(["int", "K0"]) + arg_lists.append(["int", "Batch"]) + arg_lists.append(["void*", helper.var_idx("A", 0)]) + for i in range(self.b2b_num): + arg_lists.append(["void*", helper.var_idx("B", i)]) + arg_lists.append(["void*", helper.var_idx("C", i)]) + arg_lists.append(["void*", helper.var_idx("D", i)]) + epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i]) + acc_tp = helper.get_epilogue_compute_tp(self.fuse_gemm_info[i]) + for arg in epilogue_args: + arg_tp = arg[0] + arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1] + arg_lists.append([arg_tp, arg_name]) + + if self.b2b_num == 1: + code_body += self.gen_turing_unfused.gen_using(False) #False -> Turing, True -> Volta + code_body += self.gen_turing_unfused.gen_initialize() + code_body += self.gen_turing_unfused.gen_run() + else: + code_body += self.gen_using() + code_body += self.gen_initialize() + code_body += self.gen_run() + + code = ir.gen_func(self.gen_class_name, arg_lists, code_body) + + return code + + def gen_code(self): + + code = self.gen_wrapper() + helper.write_2_headfile("turing_impl.h", self.output_dir, self.user_header_file + "\n" + code) + +class gen_volta_turing_fuse_act_impl: + def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"): + self.fuse_gemm_info = fuse_gemm_info + self.gen_class_name = gen_class_name + "_volta_impl" + self.user_header_file = "" + for header in user_header_file: + self.user_header_file += "#include \"" + header + "\"\n" + self.output_dir = output_dir + self.b2b_num = len(fuse_gemm_info) + + def perf_tiling(self, layer_mnk): + mnk = layer_mnk[:] + block_tile = mnk[:] + block_tile[2] = 32 # force the K tile to be 32 + + # M tile gen + block_tile[0] = 32 + + # N tile gen + if mnk[1] > 128: + block_tile[1] = 256 + elif mnk[1] > 64: + block_tile[1] = 128 + elif mnk[1] > 32: + block_tile[1] = 64 + else : + block_tile[1] = 32 + + warp_tile = block_tile[:] + if block_tile[1] == 256: + warp_tile[1] = 64 + elif block_tile[1] == 128: + warp_tile[1] = 32 + elif block_tile[1] == 64: + warp_tile[1] = 32 + else : + warp_tile[1] = 32 + + warp_tile[0] = 32 + + return block_tile, warp_tile + + + def process_epilogue(self, epilogue_tp, n, C_tp, Acc_tp): + epilogue_setted_type = epilogue_tp + cutlass_epilogue_name = "LinearCombinationRelu" + if epilogue_setted_type.lower() == 
'leakyrelu': + cutlass_epilogue_name = "LinearCombinationLeakyRelu" + elif epilogue_setted_type.lower() == 'identity': + cutlass_epilogue_name = "LinearCombination" + + + n_mod_8 = n % 4 + N_align_elements = 1 + if n_mod_8 == 0: + N_align_elements = 8 + elif n_mod_8 == 4: + N_align_elements = 4 + elif n_mod_8 == 2 or n_mod_8 == 6: + N_align_elements = 2 + + epilogue_str = "cutlass::epilogue::thread::" + cutlass_epilogue_name+ "<" + C_tp + ", " + str(N_align_elements) + ", " + Acc_tp + ", " + Acc_tp + ">" + + return epilogue_str + + def gen_using(self, volta = True): + code_using = "" + volta_arch = "cutlass::arch::Sm70" + volta_tc = "cutlass::gemm::GemmShape<8, 8, 4>" + + turing_arch = "cutlass::arch::Sm75" + turing_tc = "cutlass::gemm::GemmShape<16, 8, 8>" + + arch = "" + tc = "" + if volta: + arch = volta_arch + tc = volta_tc + else: + arch = turing_arch + tc = turing_tc + + for i in range(self.b2b_num): + + k = self.fuse_gemm_info[i]['mnk'][2] + + k_mod_8 = k % 4 + ab_ldm = 1 + if k_mod_8 == 0: + ab_ldm = 8 + elif k_mod_8 == 4: + ab_ldm = 4 + elif k_mod_8 == 2 or k_mod_8 == 6: + ab_ldm = 2 + + block_tile, warp_tile = self.perf_tiling(self.fuse_gemm_info[i]['mnk']) + + this_gemm_config = helper.var_idx("using Gemm", i) + " = cutlass::gemm::device::GemmBatched<\n" + this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + ",\n" + this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_format']) + ",\n" + this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_tp']) + ",\n" + this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_format']) + ",\n" + this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + ",\n" + this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_format']) + ",\n" + this_gemm_config += " " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + ",\n" + this_gemm_config += " " + "cutlass::arch::OpClassTensorOp,\n" + this_gemm_config += " " + arch + ",\n" + this_gemm_config += " " + "cutlass::gemm::GemmShape<" + str(block_tile[0]) + ", " + str(block_tile[1]) + ", " + str(block_tile[2]) + ">,\n" + this_gemm_config += " " + "cutlass::gemm::GemmShape<" + str(warp_tile[0]) + ", " + str(warp_tile[1]) + ", " + str(warp_tile[2]) + ">,\n" + this_gemm_config += " " + tc + ",\n" + this_gemm_config += " " + self.process_epilogue(helper.get_epilogue_tp(self.fuse_gemm_info[i]), self.fuse_gemm_info[i]['mnk'][1], helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']), helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp'])) + ",\n" + this_gemm_config += " " + "cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle,\n" + this_gemm_config += " " + "2,\n" + this_gemm_config += " " + str(ab_ldm) + ",\n" + this_gemm_config += " " + str(ab_ldm) + ">;\n" + + code_using += this_gemm_config + "\n" + + return code_using + "\n" + + def gen_initialize(self): + code = "" + for i in range(self.b2b_num): + code_this = "" + + N_str = str(self.fuse_gemm_info[i]['mnk'][1]) + + code_this += helper.var_idx(helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + " alpha", i) + " = " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(1);\n" + beta = "(1)" + if helper.get_epilogue_add_bias_or_not( self.fuse_gemm_info[i]) is False: + beta = "(0)" + code_this += helper.var_idx(helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + " beta", i) + " = " + 
helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + beta + ";\n" + + k_str = str(self.fuse_gemm_info[i]['mnk'][2]) + if i == 0: + k_str = "K0" + code_this += helper.var_idx("cutlass::gemm::GemmCoord problem_size_", i) + "(M, " + str(self.fuse_gemm_info[i]['mnk'][1]) + ", " + k_str + ");\n" + code_this += helper.var_idx("typename Gemm", i) + helper.var_idx("::Arguments arguments_", i) + "{\n" + code_this += " " + helper.var_idx("problem_size_", i) + ",\n" + ldmA = k_str + ldmB = k_str + ldmC = str(self.fuse_gemm_info[i]['mnk'][1]) + + ldmBias = str(helper.get_epilogue_bias_ldm(self.fuse_gemm_info[i])) + + if self.fuse_gemm_info[i]['A_format'] is 'Col': + ldmA = "M" + if self.fuse_gemm_info[i]['B_format'] is 'Row': + ldmB = str(self.fuse_gemm_info[i]['mnk'][1]) + if self.fuse_gemm_info[i]['C_format'] is 'Col': + ldmC = "M" + + if i == 0: + code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("A", i) + "), " + ldmA + "}, " + "M * " + ldmA + ",\n" + else: + code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['A_tp']) + "*>(" + helper.var_idx("D", i - 1) + "), " + ldmA + "}, " + "M * " + ldmA + ",\n" + + code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['B_tp']) + "*>(" + helper.var_idx("B", i) + "), " + ldmB + "}, " + N_str + " * " + ldmB + ",\n" + + M_bias = str(helper.get_epilogue_bias_shape(self.fuse_gemm_info[i])[0]) + + code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("C", i) + "), " + ldmBias + "}, " + M_bias + " * " + N_str + ",\n" + code_this += " " + "{reinterpret_cast<" + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['C_tp']) + "*>(" + helper.var_idx("D", i) + "), " + ldmC + "}, " + "M * " + ldmC + ",\n" + code_this += " " + "{ " + helper.var_idx("alpha", i) + ", " + helper.var_idx("beta", i) + for epilogue_arg in helper.get_epilogue_args(self.fuse_gemm_info[i]): + arg_name = helper.var_idx("Epilogue", i) + "_" + epilogue_arg[1] + code_this += ", " + helper.type_2_cutlass_type(self.fuse_gemm_info[i]['Acc_tp']) + "(" + str(arg_name) + ")" + code_this += " },\n" + code_this += " " + "Batch};\n" + + code_this += " " + helper.var_idx("Gemm", i) + helper.var_idx(" gemm_op_", i) + ";\n" + code_this += " " + helper.var_idx("gemm_op_", i) + helper.var_idx(".initialize(arguments_", i) + ", nullptr);\n" + + code += code_this + "\n" + return code + "\n" + + + def gen_run(self): + code = "" + for i in range(self.b2b_num): + code_this = "" + code_this += " " + helper.var_idx("gemm_op_", i) + "(stream);\n" + + code += code_this + return code + + def gen_wrapper(self): + code_body = "" + + arg_lists = [] + arg_lists.append(["int", "M"]) + arg_lists.append(["int", "K0"]) + arg_lists.append(["int", "Batch"]) + arg_lists.append(["void*", helper.var_idx("A", 0)]) + for i in range(self.b2b_num): + arg_lists.append(["void*", helper.var_idx("B", i)]) + arg_lists.append(["void*", helper.var_idx("C", i)]) + arg_lists.append(["void*", helper.var_idx("D", i)]) + epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i]) + acc_tp = helper.get_epilogue_compute_tp(self.fuse_gemm_info[i]) + for arg in epilogue_args: + arg_tp = arg[0] + arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1] + arg_lists.append([arg_tp, arg_name]) + code_body += self.gen_using() + code_body += self.gen_initialize() + code_body += self.gen_run() + + code = 
ir.gen_func(self.gen_class_name, arg_lists, code_body) + + return code + + def gen_code(self): + code = self.gen_wrapper() + helper.write_2_headfile("volta_impl.h", self.output_dir, self.user_header_file + "\n" + code) + +class gen_one_API: + def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"): + self.fuse_gemm_info = fuse_gemm_info + self.gen_class_name = gen_class_name + self.user_header_file = "" + for header in user_header_file: + self.user_header_file += "#include \"" + header + "\"\n" + self.output_dir = output_dir + self.b2b_num = len(fuse_gemm_info) + + self.gen_volta = gen_volta_turing_fuse_act_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir) + + self.gen_turing = gen_turing_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir) + + def gen_CUTLASS_irrelevant_API(self): + code = "" + code += "#include \n" + code += "#include \n" + + param_name = "Fused" + str(self.b2b_num) + "xGemm_" + for i in range(self.b2b_num): + param_name += str(self.fuse_gemm_info[i]['mnk'][1]) + "_" + param_name += "Params" + params = "" + params += " " + "int M;\n" + params += " " + "int K0;\n" + params += " " + "int Batch;\n" + params += " " + "const void* A0;\n" + for i in range(self.b2b_num): + params += " " + "const void* " + helper.var_idx("B", i) + ";\n" + params += " " + "const void* " + helper.var_idx("C", i) + ";\n" + epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i]) + acc_tp = helper.get_epilogue_compute_tp(self.fuse_gemm_info[i]) + for arg in epilogue_args: + arg_tp = arg[0] + arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1] + params += " " + arg_tp + " " + arg_name + ";\n" + params += " " + "void* " + helper.var_idx("D", i) + ";\n" + code += ir.gen_struct(param_name, params) + code += "using Param = " + param_name + ";\n" + code += "void one_api( const Param & param, int sm, cudaStream_t stream);\n" + + + return code + + def gen_one_api(self): + code = "" + code += "/* Auto Generated code - Do not edit.*/\n" + code += "#include \"cutlass_irrelevant.h\"\n" + code += "#include \"api.h\"\n" + code += "void one_api( const Param & param, int sm, cudaStream_t stream) {\n" + + code += " " + "if (sm == 70) \n" + code += " " + " " + self.gen_class_name + "_volta_impl(param.M, param.K0, param.Batch, const_cast(param.A0), " + for i in range(self.b2b_num): + code += helper.var_idx("const_cast(param.B", i) + "), " + code += helper.var_idx("const_cast(param.C", i) + "), " + code += helper.var_idx("param.D", i) + ", " + epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i]) + for arg in epilogue_args: + arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1] + code += "param." + arg_name + ", " + code += "stream);\n" + code += " " + "else if(sm >= 75) \n" + code += " " + " " + self.gen_class_name + "_turing_impl(param.M, param.K0, param.Batch, const_cast(param.A0), " + for i in range(self.b2b_num): + code += helper.var_idx("const_cast(param.B", i) + "), " + code += helper.var_idx("const_cast(param.C", i) + "), " + code += helper.var_idx("param.D", i) + ", " + epilogue_args = helper.get_epilogue_args(self.fuse_gemm_info[i]) + for arg in epilogue_args: + arg_name = helper.var_idx("Epilogue", i) + "_" + arg[1] + code += "param." 
+ arg_name + ", " + code += "stream);\n" + code += " " + "else assert(0);\n" + code += "}\n" + return code + + def gen_code(self): + + turing_code = self.gen_turing.gen_wrapper() + volta_code = self.gen_volta.gen_wrapper() + cutlass_irrelevant_code = self.gen_CUTLASS_irrelevant_API() + + one_api_code = self.gen_one_api() + with open(self.output_dir + "one_api.cu", "w+") as f: + f.write(one_api_code) + + helper.write_2_headfile("cutlass_irrelevant.h", self.output_dir, cutlass_irrelevant_code) + + helper.write_2_headfile("api.h", self.output_dir, self.user_header_file + "\n" + turing_code + volta_code) diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_verify.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_verify.py new file mode 100644 index 0000000000..44f3876588 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/gen_verify.py @@ -0,0 +1,92 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+# +################################################################################################# + +import helper +import gen_ir as ir + +import gen_turing_and_volta as gen_basic + + +class gen_verify: + def __init__(self, fuse_gemm_info, gen_class_name, user_header_file, output_dir = "../"): + self.fuse_gemm_info = fuse_gemm_info + self.name = gen_class_name + "_verify" + self.b2b_num = len(fuse_gemm_info) + self.params = [] + self.user_header_file = "" + for header in user_header_file: + self.user_header_file += "#include \"" + header + "\"\n" + self.separate_cutlass = gen_basic.gen_volta_turing_fuse_act_impl(fuse_gemm_info, gen_class_name, user_header_file, output_dir) + self.gen_params() + self.output_dir = output_dir + + + def gen_code(self): + code = "" + code += self.user_header_file + code += self.separate_cutlass.gen_using(False) #False -> Turing, True -> Volta + + code_body = "" + for i in range(self.b2b_num): + code_body += " " + helper.var_idx("Gemm", i) + helper.var_idx(" gemm_op_", i) + ";\n" + code_body += " " + helper.var_idx("gemm_op_", i) + helper.var_idx(".initialize(Arguments_", i) + ", nullptr);\n" + + code_body += self.separate_cutlass.gen_run() + + code += ir.gen_func(self.name, self.params, code_body) + helper.write_2_headfile("cutlass_verify.h", self.output_dir, code) + + + def gen_params(self): + for i in range(self.b2b_num): + self.params.append( + ( + helper.var_idx("typename Gemm", i)+ "::Arguments", + helper.var_idx("Arguments_", i) + ) + ) + + + def get_params(self, declartion = True): + code = "" + if declartion: + for param in self.params: + code += param[0] + " " + param[1] + ";\n" + + return code + + + def gen_initialize(): + code = "" + initialize_code = self.separate_cutlass.gen_initialize() + + code = ir.gen_func("initialize", [[]]) diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/generate.sh b/examples/44_multi_gemm_ir_and_codegen/ir_gen/generate.sh new file mode 100755 index 0000000000..19d19ea937 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/generate.sh @@ -0,0 +1,52 @@ +#!/bin/bash + +################################################################################################# +# +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +NUM_ARGS=3 +if [ $# -ne $NUM_ARGS ]; then + echo "Usage: $0 " + echo " config_file: JSON file containing configuration to run" + echo " output_directory: directory to store results" + echo " cutlass_directory: directory containing cutlass source" + exit 1 +fi + +config_file=$1 +output_dir=$2 +cutlass_dir=$3 + +python3 gen_all_code.py \ + --config-file $config_file \ + --gen-name FusedMultiGemmForward \ + --output-dir $output_dir \ + --cutlass-dir $cutlass_dir diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/helper.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/helper.py new file mode 100644 index 0000000000..d9891404fa --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/helper.py @@ -0,0 +1,135 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
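
As a quick illustration of the helper utilities defined in this file, the snippet below exercises a few of them on a made-up layer dictionary (the values are illustrative assumptions, not taken from the example's real configuration):

    # Illustrative checks for helper.py; run from the ir_gen/ directory.
    import helper

    assert helper.type_2_cutlass_type("fp16") == "cutlass::half_t"
    assert helper.type_2_cutlass_type("Row") == "cutlass::layout::RowMajor"
    assert helper.var_idx("IteratorB", 1) == "IteratorB1"

    # A vector bias on a row-major C: leading dimension 0, shape broadcast over M.
    layer = {"mnk": [0, 64, 128], "C_format": "Row", "Acc_tp": "fp32",
             "epilogue": {"tp": "relu",
                          "bias": {"addbias": True, "bias_tp": "vec"},
                          "args": []}}
    assert helper.get_epilogue_bias_ldm(layer) == 0
    assert helper.get_epilogue_bias_shape(layer) == [1, 64]
    assert helper.get_epilogue_compute_tp(layer) == "fp32"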
+# +################################################################################################# + +def type_2_cutlass_type(input_type = "fp16"): + # float point type + if input_type == "fp32": + return "float" + if input_type == "bf16": + return "cutlass::bfloat16_t" + if input_type == "fp16": + return "cutlass::half_t" + + # integer type + if(input_type == "int32"): + return "int32_t" + if(input_type == "int8"): + return "int8_t" + + if input_type == 'Row': + return 'cutlass::layout::RowMajor' + if input_type == 'Col': + return 'cutlass::layout::ColumnMajor' + +def cvt_2_cutlass_shape(gemm_shape): + # gemm shape + if len(gemm_shape) == 3: + val = "cutlass::gemm::GemmShape<" \ + + str(gemm_shape[0]) + ", " \ + + str(gemm_shape[1]) + ", " \ + + str(gemm_shape[2]) + ">" + return val + + +def write_2_headfile(filename, file_dir, string): + with open(file_dir + filename, 'w') as f: + f.write("/* Auto Generated code - Do not edit.*/\n\n\n#pragma once\n" + string) + +def var_idx(varaiable, index): + return varaiable + str(index) + + +def list_2_string(input_list, ): + rtn_string = "" + + cnt = 0 + + for element in input_list: + final = ", \n" + if cnt == len(input_list) - 1: + final = "\n" + cnt += 1 + rtn_string += str(element) + final + + return rtn_string + + +def get_epilogue_info(layer_info): + return layer_info['epilogue'] + +def get_epilogue_tp(layer_info): + epilogue_info = get_epilogue_info(layer_info) + return epilogue_info['tp'] + +def get_epilogue_add_bias_or_not(layer_info): + epilogue_info = get_epilogue_info(layer_info) + return epilogue_info['bias']['addbias'] + +def get_epilogue_add_bias_tp(layer_info): + epilogue_info = get_epilogue_info(layer_info) + return epilogue_info['bias']['bias_tp'] + +def get_epilogue_args(layer_info): + epilogue_info = get_epilogue_info(layer_info) + return epilogue_info['args'] + +def get_epilogue_bias_shape(layer_info): + bias_tp = get_epilogue_add_bias_tp(layer_info).lower() + mn_shape = layer_info['mnk'][:-1] + + if bias_tp == 'mat': + mn_shape[0] = 'M' + return mn_shape + elif bias_tp == 'vec': + mn_shape[0] = 1 + return mn_shape + else: + assert(0) + +def get_epilogue_bias_ldm(layer_info): + bias_tp = get_epilogue_add_bias_tp(layer_info).lower() + mn_shape = layer_info['mnk'][:-1] + + c_layout = layer_info['C_format'].lower() + + if c_layout != 'row': + assert(0) + + if bias_tp == 'mat': + return mn_shape[1] + elif bias_tp == 'vec': + return 0 + else: + assert(0) + +def get_epilogue_compute_tp(layer_info): + return layer_info['Acc_tp'] diff --git a/examples/44_multi_gemm_ir_and_codegen/ir_gen/replace_fix_impl_header.py b/examples/44_multi_gemm_ir_and_codegen/ir_gen/replace_fix_impl_header.py new file mode 100644 index 0000000000..bbcd050f02 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/ir_gen/replace_fix_impl_header.py @@ -0,0 +1,67 @@ +################################################################################################# +# +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. +# +################################################################################################# + +import os + +class replace_fix_impl: + def __init__(self, src_dir, dst_dir, cutlass_deps_root): + self.src_dir = src_dir + self.dst_dir = dst_dir + self.cutlass_deps_root = cutlass_deps_root + + + + def gen_code(self): + for sub_dir in os.walk(self.src_dir): + files_in_sub_dir = sub_dir[2] + + src_dirs = sub_dir[0] + output_dirs = self.dst_dir + sub_dir[0][len(self.src_dir):] + + if not os.path.exists(output_dirs): + os.mkdir(output_dirs) + + for f in files_in_sub_dir: + with open(src_dirs +"/" + f, 'r') as current_file: + output_lines = [] + lines = current_file.readlines() + + for line in lines: + if(len(line) >= len("#include \"cutlass") and line[:len("#include \"cutlass")] == "#include \"cutlass"): + new_line = "#include \"" + self.cutlass_deps_root + line[len("#include \""):] + # print(new_line) + output_lines.append(new_line) + else: + output_lines.append(line) + + with open(output_dirs + "/" + f, "w+") as dest_file: + dest_file.writelines(output_lines) diff --git a/examples/44_multi_gemm_ir_and_codegen/leaky_bias.h b/examples/44_multi_gemm_ir_and_codegen/leaky_bias.h new file mode 100644 index 0000000000..10b49049b1 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/leaky_bias.h @@ -0,0 +1,292 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
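The `replace_fix_impl` helper above rewrites only those generated lines that begin with `#include "cutlass`, splicing the configured dependency root in directly after the opening quote. A minimal C++ sketch of the same string transformation, with a hypothetical root path standing in for the generator's `--cutlass-dir` argument:

```
#include <iostream>
#include <string>

// Mirrors replace_fix_impl.gen_code: any generated line starting with
// `#include "cutlass` gets cutlass_deps_root spliced in after the opening quote.
std::string rewrite_include(std::string const &line, std::string const &cutlass_deps_root) {
  std::string const prefix = "#include \"cutlass";
  if (line.compare(0, prefix.size(), prefix) == 0) {
    return "#include \"" + cutlass_deps_root + line.substr(std::string("#include \"").size());
  }
  return line;
}

int main() {
  // Prints: #include "../../include/cutlass/gemm/device/gemm.h"
  std::cout << rewrite_include("#include \"cutlass/gemm/device/gemm.h\"", "../../include/") << "\n";
  return 0;
}
```

Everything else in each generated file is copied through unchanged.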
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once +#include + +template +__device__ +T add(T const & a, T const &b){ + return (a + b); +} + +template <> +__device__ +half2 add(half2 const & a, half2 const &b){ + return (__hadd2(a,b)); +} + +template +struct RELU{ + __device__ + T operator()(T const & a){ + return a > T(0) ? a : T(0); + } + __device__ + half2 operator()(half2 const & a){ + float2 a_fp32x2 = __half22float2(a); + a_fp32x2.x = a_fp32x2.x > 0.f ? a_fp32x2.x : 0.f; + a_fp32x2.y = a_fp32x2.y > 0.f ? a_fp32x2.y : 0.f; + if(a_fp32x2.x < 0.f || a_fp32x2.y < 0.f) + printf(" %f %f\n", a_fp32x2.x ,a_fp32x2.y); + return __float22half2_rn(a_fp32x2); + } +}; + +template +struct LEAKY_RELU{ + __device__ + T operator()(T const & a, T const & scale = half(1)){ + return a > T(0) ? a : scale * a; + } + __device__ + half2 operator()(half2 const & a, half const & scale = half(1)){ + half2 zero = __half2half2(half(0)); + half2 gt_zero = __hge2(a, zero); + half2 le_zero = __hle2(a, zero); + + + half2 scale_f16x2 = __half2half2(scale); + half2 mask_scale_f16x2 = __hfma2(le_zero, scale_f16x2, gt_zero); + return __hmul2(a, mask_scale_f16x2); + } +}; + +template +__global__ void leaky_and_activation(half* inout, half* bias, half scale, bool mat_bias){ + + constexpr bool N_MOD_2 = N & 1 ? false : true; + + using Access_tp = typename std::conditional::type; + + constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); + + constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); + + LEAKY_RELU Act; + Access_tp src_v[iter]; + Access_tp bias_v[iter]; + + int batch_id = blockIdx.y; + int batch_offset = batch_id * gridDim.x * N; + + for(int i = 0; i < iter; i++){ + int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; + if (idx < N){ + src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); + if (mat_bias) + bias_v[i] = *reinterpret_cast(bias + blockIdx.x * N + idx + batch_offset); + else + bias_v[i] = *reinterpret_cast(bias + idx + batch_id * N); + *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = Act(add(src_v[i],bias_v[i]),scale); + } + + } +} + + + +template +__global__ void leaky_and_activation(half* inout, half scale){ + + constexpr bool N_MOD_2 = N & 1 ? 
false : true; + + using Access_tp = typename std::conditional::type; + + constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); + + constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); + + int batch_id = blockIdx.y; + int batch_offset = batch_id * gridDim.x * N; + + LEAKY_RELU Act; + Access_tp src_v[iter]; + + for(int i = 0; i < iter; i++){ + int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; + if (idx < N){ + src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); + *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = Act(src_v[i], scale); + } + + } +} + + + +template +void leaky_and_activation(half* inout, half* bias, int m, int b, half scale, bool mat_bias){ + + dim3 grid(m, b); + if (bias == nullptr) + leaky_and_activation<<>>(inout, scale); + else + leaky_and_activation<<>>(inout, bias, scale, mat_bias); +} + +template +__global__ void relu_and_activation(half* inout, half* bias, bool mat_bias){ + + constexpr bool N_MOD_2 = N & 1 ? false : true; + + using Access_tp = typename std::conditional::type; + + constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); + + constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); + + RELU Act; + Access_tp src_v[iter]; + Access_tp bias_v[iter]; + + int batch_id = blockIdx.y; + int batch_offset = batch_id * gridDim.x * N; + + for(int i = 0; i < iter; i++){ + int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; + if (idx < N){ + src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); + if (mat_bias) + bias_v[i] = *reinterpret_cast(bias + blockIdx.x * N + idx + batch_offset); + else + bias_v[i] = *reinterpret_cast(bias + idx + batch_id * N); + *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = Act(add(src_v[i],bias_v[i])); + } + + } +} + + + +template +__global__ void relu_and_activation(half* inout){ + + constexpr bool N_MOD_2 = N & 1 ? false : true; + + using Access_tp = typename std::conditional::type; + + constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); + + constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); + + int batch_id = blockIdx.y; + int batch_offset = batch_id * gridDim.x * N; + + RELU Act; + Access_tp src_v[iter]; + + for(int i = 0; i < iter; i++){ + int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; + if (idx < N){ + src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); + *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = Act(src_v[i]); + } + + } +} + + + +template +void relu_and_activation(half* inout, half* bias, int m, int b, bool mat_bias){ + dim3 grid(m, b); + if (bias == nullptr) + relu_and_activation<<>>(inout); + else + relu_and_activation<<>>(inout, bias, mat_bias); +} + + +template +__global__ void identity_and_activation(half* inout, half* bias, bool mat_bias){ + + constexpr bool N_MOD_2 = N & 1 ? 
false : true; + + using Access_tp = typename std::conditional::type; + + constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); + + constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); + + int batch_id = blockIdx.y; + int batch_offset = batch_id * gridDim.x * N; + + Access_tp src_v[iter]; + Access_tp bias_v[iter]; + + for(int i = 0; i < iter; i++){ + int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; + if (idx < N){ + src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); + if (mat_bias) + bias_v[i] = *reinterpret_cast(bias + blockIdx.x * N + idx + batch_offset); + else + bias_v[i] = *reinterpret_cast(bias + idx + batch_id * N); + *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = (add(src_v[i],bias_v[i])); + } + + } +} + +template +__global__ void identity_and_activation(half* inout){ + + constexpr bool N_MOD_2 = N & 1 ? false : true; + + using Access_tp = typename std::conditional::type; + + constexpr int Access_elements = sizeof(Access_tp) / sizeof(half); + + constexpr int iter = (N + (BLOCKDIM * Access_elements) - 1 ) / (BLOCKDIM * Access_elements); + + int batch_id = blockIdx.y; + int batch_offset = batch_id * gridDim.x * N; + Access_tp src_v[iter]; + + for(int i = 0; i < iter; i++){ + int idx = (i * BLOCKDIM + threadIdx.x) * Access_elements; + if (idx < N){ + src_v[i] = *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset); + *reinterpret_cast(inout + blockIdx.x * N + idx + batch_offset) = (src_v[i]); + } + + } +} + +template +void identity_and_activation(half* inout, half* bias, int m, int b, bool mat_bias){ + dim3 grid(m, b); + if (bias == nullptr) + identity_and_activation<<>>(inout); + else + identity_and_activation<<>>(inout, bias, mat_bias); +} diff --git a/examples/44_multi_gemm_ir_and_codegen/utils.h b/examples/44_multi_gemm_ir_and_codegen/utils.h new file mode 100644 index 0000000000..2b05ae9367 --- /dev/null +++ b/examples/44_multi_gemm_ir_and_codegen/utils.h @@ -0,0 +1,94 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
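The half2 path of the `LEAKY_RELU` functor in `leaky_bias.h` above exploits the fact that `__hge2`/`__hle2` return 1.0 or 0.0 per lane, so the negative-side scale can be blended in with a single `__hfma2` and one multiply. A small standalone check of that masking idea (a hypothetical test, not part of the generated example):

```
#include <cstdio>
#include <cuda_fp16.h>
#include <cuda_runtime.h>

// Per-lane masking trick used by LEAKY_RELU above:
// mask = (x <= 0) * scale + (x > 0), then y = x * mask.
__global__ void leaky_relu_half2(half2 *x, half scale) {
  half2 v = *x;
  half2 zero = __half2half2(__float2half(0.f));
  half2 gt_zero = __hge2(v, zero);   // 1.0 where v >= 0, else 0.0
  half2 le_zero = __hle2(v, zero);   // 1.0 where v <= 0, else 0.0
  half2 mask = __hfma2(le_zero, __half2half2(scale), gt_zero);
  *x = __hmul2(v, mask);
}

int main() {
  half2 h = __floats2half2_rn(2.0f, -4.0f);
  half2 *d;
  cudaMalloc(&d, sizeof(half2));
  cudaMemcpy(d, &h, sizeof(half2), cudaMemcpyHostToDevice);
  leaky_relu_half2<<<1, 1>>>(d, __float2half(0.1f));
  cudaMemcpy(&h, d, sizeof(half2), cudaMemcpyDeviceToHost);
  // Expect approximately 2.0 and -0.4 for scale = 0.1.
  std::printf("%f %f\n", __half2float(__low2half(h)), __half2float(__high2half(h)));
  cudaFree(d);
  return 0;
}
```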
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once +#define TI(tag) \ + cudaEvent_t _event_start_ ##tag; \ + cudaEvent_t _event_end_ ##tag; \ + float _event_time_ ##tag; \ + cudaEventCreate(& _event_start_ ##tag); \ + cudaEventCreate(& _event_end_ ##tag); \ + cudaEventRecord(_event_start_ ##tag); + +#define TO(tag, str, times) \ + cudaEventRecord(_event_end_ ##tag); \ + cudaEventSynchronize(_event_end_ ##tag); \ + cudaEventElapsedTime(&_event_time_ ##tag, _event_start_ ##tag, _event_end_ ##tag); \ + float _event_time_once_ ##tag = _event_time_ ##tag / times; \ + printf("%20s:\t %10.3fus\t", str, _event_time_once_ ##tag * 1000); \ + cudaDeviceSynchronize(); \ + printf("%20s string: %s\n",str, cudaGetErrorString(cudaGetLastError())); + +template +struct memory_unit{ + T* host_ptr; + T* device_ptr; + int size_bytes; + int elements; + void h2d(){ + cudaMemcpy(device_ptr, host_ptr, size_bytes, cudaMemcpyHostToDevice); + } + void d2h(){ + cudaMemcpy(host_ptr, device_ptr, size_bytes, cudaMemcpyDeviceToHost); + } + void free_all(){ + free(host_ptr); + cudaFree(device_ptr); + } + memory_unit(int elements_): size_bytes(elements_ * sizeof(T)), elements(elements_){ + host_ptr = (T*) malloc(elements_ * sizeof(T)); + cudaMalloc((void**)&device_ptr, elements_ * sizeof(T)); + } + void init(int abs_range = 1){ + for(int i = 0; i < elements; i++){ + host_ptr[i] = T(rand() % 100 / float(100) * 2 * abs_range - abs_range); + } + h2d(); + } +}; + +template +int check_result(T * a, T * b, int N){ + int cnt = 0; + for(int i = 0; i < N; i ++){ + float std = float(a[i]); + float my = float(b[i]); + + if(abs(std - my) / abs(std) > 1e-2) + { + // printf("my: %f , std: %f\n", my, std); + cnt++; + } + + } + printf("total err: %d / %d\n", cnt, N); + return cnt; +} diff --git a/examples/45_dual_gemm/CMakeLists.txt b/examples/45_dual_gemm/CMakeLists.txt new file mode 100644 index 0000000000..de704ed2b1 --- /dev/null +++ b/examples/45_dual_gemm/CMakeLists.txt @@ -0,0 +1,36 @@ + +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
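The `TI`/`TO` macros and the `memory_unit<T>` helper in `utils.h` above wrap CUDA-event timing and paired host/device allocations. A minimal sketch of how they are typically combined; the kernel below is a stand-in, and `utils.h` is assumed to be on the include path:

```
#include <cstdio>
#include <cuda_runtime.h>
#include "utils.h"  // the helper shown above

// Trivial kernel used only to have something to time.
__global__ void scale_kernel(float *x, int n) {
  int i = blockIdx.x * blockDim.x + threadIdx.x;
  if (i < n) { x[i] *= 2.f; }
}

int main() {
  int const runs = 100;
  memory_unit<float> buf(1 << 20);  // allocates matching host and device buffers
  buf.init();                       // random host data in [-1, 1], then copies to device

  TI(scale);                        // create CUDA events and record the start
  for (int i = 0; i < runs; i++) {
    scale_kernel<<<(buf.elements + 255) / 256, 256>>>(buf.device_ptr, buf.elements);
  }
  TO(scale, "scale_kernel", runs);  // record the end, print the average time per launch

  buf.d2h();                        // copy results back for host-side inspection
  buf.free_all();
  return 0;
}
```

`TO` reports the average time per launch in microseconds along with the most recent CUDA error string.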
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + +cutlass_example_add_executable( + 45_dual_gemm + dual_gemm.cu + ) + diff --git a/examples/45_dual_gemm/device/dual_gemm.h b/examples/45_dual_gemm/device/dual_gemm.h new file mode 100644 index 0000000000..f48073597f --- /dev/null +++ b/examples/45_dual_gemm/device/dual_gemm.h @@ -0,0 +1,499 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Performs a dual gemm in one fused kernel: +``` +D0 = epilogue0(X @ B0, C0) +D1 = epilogue1(X @ B1, C1) +D2 = element_wise(D0, D1) +``` +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/arch.h" +#include "cutlass/device_kernel.h" + +#include "cutlass/gemm/threadblock/threadblock_swizzle.h" + +#include "cutlass/gemm/device/default_gemm_configuration.h" +#include "cutlass/gemm/threadblock/default_mma.h" +#include "cutlass/epilogue/thread/linear_combination_relu.h" +#include "cutlass/epilogue/threadblock/default_epilogue_tensor_op.h" + +#include "../kernel/dual_gemm.h" +#include "../dual_gemm_common.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Element type for A matrix operand + typename ElementA_, + /// Layout type for A matrix operand + typename LayoutA_, + /// Element type for B matrix operand + typename ElementB_, + /// Layout type for B0 matrix operand + typename LayoutB0_, + /// Layout type for B1 matrix operand + typename LayoutB1_, + /// Element type for C and D matrix operands + typename ElementC_, + /// Layout type for C and D matrix operands + typename LayoutC_, + /// Element type for internal accumulation + typename ElementAccumulator_, + /// Operator class tag + typename OperatorClass_, + /// Tag indicating architecture to tune for + typename ArchTag_, + /// Threadblock-level tile size (concept: GemmShape) + typename ThreadblockShape_, + /// Warp-level tile size (concept: GemmShape) + typename WarpShape_, + /// Instruction-level tile size (concept: GemmShape) + typename InstructionShape_, + /// Epilogue output operator + typename EpilogueOutputOp0_, + typename EpilogueOutputOp1_, + typename EpilogueOutputOp2_, + /// Threadblock-level swizzling operator + typename ThreadblockSwizzle_ = threadblock::GemmIdentityThreadblockSwizzle<>, + /// Number of stages used in the pipelined mainloop + int Stages = + DefaultGemmConfiguration::kStages, + bool StoreD0 = true, + bool StoreD1 = true, + /// If true, kernel supports split-K with serial reduction + bool SplitKSerial = false, + /// Access granularity of A matrix in units of elements + int AlignmentA = + DefaultGemmConfiguration::kAlignmentA, + /// Access granularity of B matrix in units of elements + int AlignmentB = + DefaultGemmConfiguration::kAlignmentB, + /// Operation performed by GEMM + typename Operator_ = typename DefaultGemmConfiguration< + OperatorClass_, ArchTag_, ElementA_, ElementB_, ElementC_, + ElementAccumulator_>::Operator> +class DualGemm { + public: + + using ElementA = ElementA_; + using LayoutA = LayoutA_; + using TensorRefA = TensorRef; + using ElementB = ElementB_; + using LayoutB0 = LayoutB0_; + using LayoutB1 = LayoutB1_; + using TensorRefB0 = TensorRef; + using TensorRefB1 = TensorRef; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using TensorRefC = TensorRef; + using TensorRefD = TensorRef; + using ElementAccumulator = ElementAccumulator_; + using OperatorClass = OperatorClass_; + using ArchTag = ArchTag_; + using ThreadblockShape = ThreadblockShape_; + using WarpShape = WarpShape_; + using InstructionShape = InstructionShape_; + using EpilogueOutputOp0 = EpilogueOutputOp0_; + using EpilogueOutputOp1 = EpilogueOutputOp1_; + using EpilogueOutputOp2 = EpilogueOutputOp2_; + using ThreadblockSwizzle = 
ThreadblockSwizzle_; + using Operator = Operator_; + static int const kStages = Stages; + static int const kAlignmentA = AlignmentA; + static int const kAlignmentB = AlignmentB; + static int const kAlignmentC = EpilogueOutputOp1::kCount; + static bool const kSplitKSerial = SplitKSerial; + static bool constexpr kStoreD0 = StoreD0; + static bool constexpr kStoreD1 = StoreD1; + static ComplexTransform const kTransformA = ComplexTransform::kNone; + static ComplexTransform const kTransformB = ComplexTransform::kNone; + + using LayoutScaleBias = layout::RowMajor; + /// Define the kernel + /// Define the threadblock-scoped matrix multiply-accumulate + static_assert(ArchTag::kMinComputeCapability >= 80, "Only multistage is implemented"); + static_assert(kStages >= 3, "Only multistage is implemented"); + using Mma0 = typename cutlass::gemm::threadblock::DefaultMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB0, kAlignmentB, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, + ThreadblockShape, WarpShape, + InstructionShape, Stages, Operator>::ThreadblockMma; + using Mma1 = typename cutlass::gemm::threadblock::DefaultMma< + ElementA, LayoutA, kAlignmentA, ElementB, LayoutB1, kAlignmentB, + ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, ArchTag, + ThreadblockShape, WarpShape, + InstructionShape, Stages, Operator>::ThreadblockMma; + using DualMma = threadblock::DualMmaMultistage< + typename Mma0::Shape, + typename Mma0::IteratorA, + typename Mma0::SmemIteratorA, + Mma0::kCacheOpA, + typename Mma0::IteratorB, + typename Mma0::SmemIteratorB, + Mma0::kCacheOpB, + typename Mma1::IteratorB, + typename Mma1::SmemIteratorB, + typename Mma0::ElementC, + typename Mma0::LayoutC, + typename Mma0::Policy, + typename Mma1::Policy, + Mma0::kStages, + SharedMemoryClearOption::kNone + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + /// Define the epilogue + using Epilogue0 = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, typename DualMma::Operator0, kPartitionsK, EpilogueOutputOp0, + EpilogueOutputOp0::kCount>::Epilogue; + using Epilogue1 = + typename cutlass::epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, typename DualMma::Operator1, kPartitionsK, EpilogueOutputOp1, + EpilogueOutputOp1::kCount>::Epilogue; + + /// Define the kernel-level GEMM operator. 
+ using DualGemmKernel = kernel::DualGemm< + DualMma, + Epilogue0, Epilogue1, EpilogueOutputOp2, + ThreadblockSwizzle, kSplitKSerial, + kStoreD0, kStoreD1>; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + DualGemmMode mode; + GemmCoord problem_size; + TensorRef ref_A0; + TensorRef ref_B0; + TensorRef ref_C0; + TensorRef ref_D0; + TensorRef ref_B1; + TensorRef ref_C1; + TensorRef ref_D1; + TensorRef ref_D2; + typename EpilogueOutputOp0::Params epilogue0; + typename EpilogueOutputOp1::Params epilogue1; + typename EpilogueOutputOp2::Params epilogue2; + int split_k_slices; + + int batch_count; + int64_t batch_stride_A; + int64_t batch_stride_B0; + int64_t batch_stride_B1; + int64_t batch_stride_C; + int64_t batch_stride_D; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments(): problem_size(0, 0, 0), split_k_slices(1) { + + } + + /// Constructs an Arguments structure + CUTLASS_HOST_DEVICE + Arguments( + DualGemmMode mode, + GemmCoord problem_size_, + TensorRef ref_A0_, + TensorRef ref_B0_, + TensorRef ref_C0_, + TensorRef ref_D0_, + TensorRef ref_B1_, + TensorRef ref_C1_, + TensorRef ref_D1_, + TensorRef ref_D2_, + typename EpilogueOutputOp0::Params epilogue0_ = + typename EpilogueOutputOp0::Params(), + typename EpilogueOutputOp1::Params epilogue1_ = + typename EpilogueOutputOp1::Params(), + typename EpilogueOutputOp2::Params epilogue2_ = + typename EpilogueOutputOp2::Params(), + int split_k_slices_ = 1, + int batch_count = 1, + int64_t batch_stride_A = 0, + int64_t batch_stride_B0 = 0, + int64_t batch_stride_B1 = 0, + int64_t batch_stride_C = 0, + int64_t batch_stride_D = 0 + ): + mode(mode), + problem_size(problem_size_), + ref_A0(ref_A0_), + ref_B0(ref_B0_), + ref_C0(ref_C0_), + ref_D0(ref_D0_), + ref_B1(ref_B1_), + ref_C1(ref_C1_), + ref_D1(ref_D1_), + ref_D2(ref_D2_), + epilogue0(epilogue0_), + epilogue1(epilogue1_), + epilogue2(epilogue2_), + split_k_slices(split_k_slices_), + batch_count(batch_count), + batch_stride_A(batch_stride_A), + batch_stride_B0(batch_stride_B0), + batch_stride_B1(batch_stride_B1), + batch_stride_C(batch_stride_C), + batch_stride_D(batch_stride_D) { + + } + }; + +private: + + /// Kernel parameters object + typename DualGemmKernel::Params params_; + +public: + + /// Constructs the GEMM. + DualGemm() = default; + + /// Determines whether the GEMM can execute the given problem. 
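Putting the declarations above together, here is a hedged host-side sketch of how a `DualGemm` instantiation is typically driven once the operands are resident in device memory. The free function and its name are illustrative only; `device/dual_gemm.h` is the header defined here:

```
#include "device/dual_gemm.h"            // this header
#include "cutlass/util/device_memory.h"  // cutlass::device_memory::allocation

// Illustrative driver: populate Arguments, validate, initialize, and launch.
template <typename DualGemm>
cutlass::Status run_dual_gemm(
    cutlass::gemm::GemmCoord problem_size,
    typename DualGemm::TensorRefA ref_A0,
    typename DualGemm::TensorRefB0 ref_B0,
    typename DualGemm::TensorRefC ref_C0,
    typename DualGemm::TensorRefD ref_D0,
    typename DualGemm::TensorRefB1 ref_B1,
    typename DualGemm::TensorRefC ref_C1,
    typename DualGemm::TensorRefD ref_D1,
    typename DualGemm::TensorRefD ref_D2,
    typename DualGemm::EpilogueOutputOp0::Params epilogue0,
    typename DualGemm::EpilogueOutputOp1::Params epilogue1) {

  typename DualGemm::Arguments args(
      cutlass::gemm::DualGemmMode::kGemm, problem_size,
      ref_A0, ref_B0, ref_C0, ref_D0, ref_B1, ref_C1, ref_D1, ref_D2,
      epilogue0, epilogue1, typename DualGemm::EpilogueOutputOp2::Params(),
      /*split_k_slices=*/1);

  cutlass::Status status = DualGemm::can_implement(args);
  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  // Serial split-K (if enabled) needs a zero-initialized semaphore workspace.
  cutlass::device_memory::allocation<uint8_t> workspace(DualGemm::get_workspace_size(args));

  DualGemm dual_gemm_op;
  status = dual_gemm_op.initialize(args, workspace.get());
  if (status != cutlass::Status::kSuccess) {
    return status;
  }

  return dual_gemm_op();  // run on the default stream
}
```

The `can_implement` check that follows in this header performs the same argument validation (mode, split-K, and the `kStoreD0`/`kStoreD1` flags) before any launch.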
+ static Status can_implement(Arguments const &args) { + + if (args.mode == DualGemmMode::kBatched && kSplitKSerial) { + return Status::kErrorInvalidProblem; + } + if (!kSplitKSerial && args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + if (kStoreD0 != (args.ref_D0.data() != nullptr)) { + return Status::kErrorInternal; + } + if (kStoreD1 != (args.ref_D1.data() != nullptr)) { + return Status::kErrorInternal; + } + + Status status = DualGemmKernel::can_implement( + args.problem_size, + args.ref_A0.non_const_ref(), + args.ref_B0.non_const_ref(), + args.ref_C0.non_const_ref(), + args.ref_D0, + args.ref_B1.non_const_ref(), + args.ref_C1.non_const_ref(), + args.ref_D1, + args.ref_D2 + ); + + if (status != Status::kSuccess) { + return status; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + + size_t bytes = 0; + + if (kSplitKSerial && args.split_k_slices > 1) { + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord tiled_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.split_k_slices); + + bytes += sizeof(int) * size_t(tiled_shape.m()) * size_t(tiled_shape.n()); + } + + return bytes; + } + + /// Initializes GEMM state from arguments. + Status initialize(Arguments const &args, void *workspace = nullptr, cudaStream_t stream = nullptr) { + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord grid_shape = threadblock_swizzle.get_tiled_shape( + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.mode == DualGemmMode::kBatched ? args.batch_count : args.split_k_slices); + + if (kSplitKSerial) { + if (args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + + size_t bytes = get_workspace_size(args); + + cudaError_t result = cudaMemsetAsync(workspace, 0, bytes, stream); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + } + else { + + if (args.split_k_slices > 1) { + return Status::kErrorInvalidProblem; + } + } + + // Initialize the Params structure + params_ = typename DualGemmKernel::Params{ + args.mode, + args.problem_size, + grid_shape, + args.ref_A0.non_const_ref(), + args.ref_B0.non_const_ref(), + args.ref_C0.non_const_ref(), + args.ref_D0, + args.ref_B1.non_const_ref(), + args.ref_C1.non_const_ref(), + args.ref_D1, + args.ref_D2, + args.epilogue0, + args.epilogue1, + args.epilogue2, + reinterpret_cast(workspace), + args.batch_stride_A, + args.batch_stride_B0, + args.batch_stride_B1, + args.batch_stride_C, + args.batch_stride_D, + }; + + return Status::kSuccess; + } + + /// Lightweight update given a subset of arguments + Status update(Arguments const &args, void *workspace = nullptr) { + + if (kSplitKSerial && args.split_k_slices > 1) { + if (!workspace) { + return Status::kErrorWorkspaceNull; + } + } + + params_.ref_A0.reset(args.ref_A0.non_const_ref().data()); + params_.ref_B0.reset(args.ref_B0.non_const_ref().data()); + params_.ref_C0.reset(args.ref_C0.non_const_ref().data()); + params_.ref_D0.reset(args.ref_D0.data()); + params_.ref_B1.reset(args.ref_B1.non_const_ref().data()); + params_.ref_C1.reset(args.ref_C1.non_const_ref().data()); + params_.ref_D1.reset(args.ref_D1.data()); + params_.ref_D2.reset(args.ref_D2.data()); + params_.output_op_0 = args.epilogue0; + params_.output_op_1 = args.epilogue1; + 
params_.output_op_2 = args.epilogue2; + params_.semaphore = reinterpret_cast(workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(DualGemmKernel::kThreadCount, 1, 1); + + cudaError_t result; + + int smem_size = int(sizeof(typename DualGemmKernel::SharedStorage)); + if (smem_size >= (48 << 10)) { + result = cudaFuncSetAttribute(Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + cutlass::Kernel<<>>(params_); + + result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } +}; + +} // namespace device +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/45_dual_gemm/dual_gemm.cu b/examples/45_dual_gemm/dual_gemm.cu new file mode 100644 index 0000000000..8043addec1 --- /dev/null +++ b/examples/45_dual_gemm/dual_gemm.cu @@ -0,0 +1,460 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief CUTLASS Dual-GEMM Example. 
+ + Fused kernel that outputs `D0` and `D1`. + We assume that B0/B1 have the same shape/layout + +``` +D0 = epilogue0(X @ B0, C0) +D1 = epilogue1(X @ B1, C1) +D2 = element_wise(D0, D1) +``` + D0 and D1 will be optionally stored in gmem (`kStoreD0` / `kStoreD1`) +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "device/dual_gemm.h" +#include "thread/left_silu_and_mul.h" +#include "dual_gemm_run.h" +#include "test_run.h" + + +//////////////////////////////////////////////////////////////////////////////// + +cutlass::gemm::GemmCoord problem_size(4096, 4096, 8192); +cutlass::gemm::GemmCoord batch_problem_size(321, 256, 512); + +constexpr int kStages = 3; +constexpr bool kSplitKSerial = false; +constexpr bool kUseBias = true; +constexpr int kBatchCount = 37; + + +#if 0 +using ElementOperandA = cutlass::bfloat16_t; +using ElementOperandB = cutlass::bfloat16_t; +using ElementOutput = cutlass::bfloat16_t; +using ElementAccumulator = float; +using ElementCompute = float; +#else +using ElementOperandA = cutlass::half_t; +using ElementOperandB = cutlass::half_t; +using ElementOutput = cutlass::half_t; +using ElementAccumulator = cutlass::half_t; +using ElementCompute = cutlass::half_t; +#endif + +constexpr auto kScaleType = kUseBias ? cutlass::epilogue::thread::ScaleType::NoBetaScaling : ( + // No bias + kSplitKSerial ? cutlass::epilogue::thread::ScaleType::Default : cutlass::epilogue::thread::ScaleType::Nothing +); +using EpilogueOutputOp0 = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute, + kScaleType +>; +using EpilogueOutputOp1 = cutlass::epilogue::thread::LinearCombination< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + ElementCompute, + kScaleType +>; +using EpilogueOutputOp2 = cutlass::epilogue::thread::LeftSiLUAndMul< + ElementOutput, + 128 / cutlass::sizeof_bits::value, + ElementOutput, + ElementCompute +>; + +const ElementCompute alpha0 = ElementCompute(1); +const ElementCompute beta0 = ElementCompute(kUseBias ? 1 : 0); +const ElementCompute alpha1 = ElementCompute(1); +const ElementCompute beta1 = ElementCompute(kUseBias ? 
1 : 0); + +bool run_nonfused_gemm_f16_sm80() { + using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + using Gemm0 = cutlass::gemm::device::Gemm< + ElementOperandA, + cutlass::layout::RowMajor, + ElementOperandB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp0, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, + kStages, + 8, + 8, + kSplitKSerial + >; + using Gemm1 = cutlass::gemm::device::Gemm< + ElementOperandA, + cutlass::layout::RowMajor, + ElementOperandB, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp1, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, + kStages, + 8, + 8, + kSplitKSerial + >; + + NonFusedDualGemmRun nonFusedGemm; + + std::cout << "Running Non-fused GEMMs FP16 TN GEMMs...\n"; + + bool pass = nonFusedGemm.run( + problem_size, + alpha0, + beta0, + alpha1, + beta1, + true /* is_profiling */ + ); + + if(pass) + std::cout << "Pass\n"; + else + std::cout << "Fail\n"; + + return pass; +} + +template +struct LeftSiLUAndMul { + struct Params{}; + CUTLASS_HOST_DEVICE LeftSiLUAndMul(Params p) {} + + CUTLASS_HOST_DEVICE void set_k_partition(int, int) {} + + CUTLASS_HOST_DEVICE T operator() ( + T const &lhs, + T const &rhs) const { + cutlass::epilogue::thread::SiLu silu; + cutlass::multiplies mul; + auto silu_lhs = silu(lhs); + return mul(silu_lhs, rhs); + } + + template + CUTLASS_HOST_DEVICE cutlass::Array operator() ( + cutlass::Array const &lhs, + cutlass::Array const &rhs) const { + cutlass::epilogue::thread::SiLu silu; + cutlass::multiplies mul; + auto silu_lhs = silu(lhs); + return mul(silu_lhs, rhs); + } +}; + +bool run_fused_gemm_f16_sm80_shmem() { + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + // Optionally, we might not need intermediate GEMM outputs + constexpr bool kStoreD0 = true; + constexpr bool kStoreD1 = true; + + using DualGemm = cutlass::gemm::device::DualGemm< + ElementOperandA, + cutlass::layout::RowMajor, + ElementOperandB, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + EpilogueOutputOp2, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, + kStages, + kStoreD0, + kStoreD1, + kSplitKSerial + >; + + DualFusedGemmRun fusedGemm; + + std::cout << "Running Fused FP16 TN GEMMs + Epilogue2...\n"; + + bool passed = fusedGemm.run( + problem_size, + alpha0, + beta0, + alpha1, + beta1 + ); + + if(passed) + std::cout << "Pass\n"; + else + std::cout << "Fail\n"; + + return passed; +} + +bool run_batched_fused_gemm_f16_sm80_shmem() { + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + // Optionally, we might not need 
intermediate GEMM outputs + constexpr bool kStoreD0 = true; + constexpr bool kStoreD1 = true; + + using DualGemm = cutlass::gemm::device::DualGemm< + ElementOperandA, + cutlass::layout::RowMajor, + ElementOperandB, + cutlass::layout::ColumnMajor, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + EpilogueOutputOp2, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, + kStages, + kStoreD0, + kStoreD1, + kSplitKSerial + >; + + DualFusedGemmRun fusedGemm; + + std::cout << "Running Batched Fused FP16 TN GEMMs + Epilogue2...\n"; + + bool passed = fusedGemm.run( + batch_problem_size, + alpha0, + beta0, + alpha1, + beta1, + kBatchCount, + false, /* broadcast_b1 */ + false /* is_profiling */ + ); + + if(passed) + std::cout << "Pass\n"; + else + std::cout << "Fail\n"; + + return passed; +} + +bool run_broadcast_fused_gemm_f16_sm80_shmem() { + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + // Optionally, we might not need intermediate GEMM outputs + constexpr bool kStoreD0 = true; + constexpr bool kStoreD1 = true; + + using DualGemm = cutlass::gemm::device::DualGemm< + ElementOperandA, + cutlass::layout::RowMajor, + ElementOperandB, + // different LayoutB0 and B1 + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + EpilogueOutputOp2, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, + kStages, + kStoreD0, + kStoreD1, + kSplitKSerial + >; + + DualFusedGemmRun fusedGemm; + + std::cout << "Running Broadcast Fused FP16 TN GEMMs + Epilogue2...\n"; + + bool passed = fusedGemm.run( + problem_size, + alpha0, + beta0, + alpha1, + beta1, + 1, /* batch_count */ + true, /* broadcast_b1 */ + true /* is_profiling */ + ); + + if(passed) + std::cout << "Pass\n"; + else + std::cout << "Fail\n"; + + return passed; +} + +bool run_batched_broadcast_fused_gemm_f16_sm80_shmem() { + using ThreadblockShape = cutlass::gemm::GemmShape<128, 64, 32>; + using WarpShape = cutlass::gemm::GemmShape<64, 32, 32>; + using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; + + // Optionally, we might not need intermediate GEMM outputs + constexpr bool kStoreD0 = true; + constexpr bool kStoreD1 = true; + + using DualGemm = cutlass::gemm::device::DualGemm< + ElementOperandA, + cutlass::layout::RowMajor, + ElementOperandB, + // different LayoutB0 and B1 + cutlass::layout::RowMajor, + cutlass::layout::ColumnMajor, + ElementOutput, + cutlass::layout::RowMajor, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp0, + EpilogueOutputOp1, + EpilogueOutputOp2, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<1>, + kStages, + kStoreD0, + kStoreD1, + kSplitKSerial + >; + + DualFusedGemmRun fusedGemm; + + std::cout << "Running Batch Broadcast Fused FP16 TN GEMMs + Epilogue2...\n"; + + bool passed = fusedGemm.run( + batch_problem_size, + alpha0, + beta0, + alpha1, + beta1, + kBatchCount, + true, /* broadcast_b1 */ + false /* is_profiling */ + 
); + + if(passed) + std::cout << "Pass\n"; + else + std::cout << "Fail\n"; + + return passed; +} + +int main() { + + std::vectorfuncs = { + &run_nonfused_gemm_f16_sm80, + &run_fused_gemm_f16_sm80_shmem, + &run_batched_fused_gemm_f16_sm80_shmem, + &run_broadcast_fused_gemm_f16_sm80_shmem, + &run_batched_broadcast_fused_gemm_f16_sm80_shmem + }; + + std::string test_name = ( + "dual-gemm f16 bias=" + + std::to_string(kUseBias) + + " split_k_serial=" + + std::to_string(kSplitKSerial) + + " batch_count=" + + std::to_string(kBatchCount) + ); + + return testRun(80, funcs, test_name); +} + + + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/45_dual_gemm/dual_gemm_common.h b/examples/45_dual_gemm/dual_gemm_common.h new file mode 100644 index 0000000000..41f5cfea6a --- /dev/null +++ b/examples/45_dual_gemm/dual_gemm_common.h @@ -0,0 +1,52 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines common types used for all DualGemm operators. 
+*/ +#pragma once + +namespace cutlass { +namespace gemm { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +enum class DualGemmMode { + kGemm, + kBatched, + kInvalid +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace gemm +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/45_dual_gemm/dual_gemm_run.h b/examples/45_dual_gemm/dual_gemm_run.h new file mode 100644 index 0000000000..b53ee80668 --- /dev/null +++ b/examples/45_dual_gemm/dual_gemm_run.h @@ -0,0 +1,938 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
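For reference, the second-stage element-wise epilogue exercised by this example computes `D2 = SiLU(D0) * D1`, where `SiLU(x) = x * sigmoid(x)`. A host-only scalar sketch of that math, independent of the CUTLASS functors (the numeric values in the comment are just a worked example):

```
#include <cmath>
#include <cstdio>

// Scalar reference for the LeftSiLUAndMul epilogue: d2 = silu(d0) * d1,
// where silu(x) = x * sigmoid(x) = x / (1 + exp(-x)).
float left_silu_and_mul_ref(float d0, float d1) {
  float silu_d0 = d0 / (1.0f + std::exp(-d0));
  return silu_d0 * d1;
}

int main() {
  // Worked example: silu(1.0) is roughly 0.7311, so the result is roughly 1.4621.
  std::printf("%f\n", left_silu_and_mul_ref(1.0f, 2.0f));
  return 0;
}
```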
+ * + **************************************************************************************************/ +#pragma once + +#include +#include +#include +#include + +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_relu.h" + +#include "cutlass/platform/platform.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/device/gemm_universal.h" + +#include "dual_gemm_common.h" +#include "helper.h" + +#define CHECK_GT(val1, val2) \ + if((val1) <= (val2)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_GT failed\n"; +#define CHECK_TRUE(val) \ + if(!(val)) \ + std::cerr << __FILE__ << " " << __LINE__ << ": CHECK_TRUE failed\n"; + +template < + typename OutputOp, + typename Element, + typename Layout> +struct TensorEpilogueForEachFunc { + /// View type + using TensorView = cutlass::TensorView; + + /// Coordinate in tensor's index space + using TensorCoord = typename TensorView::TensorCoord; + + /// Parameters structure + struct Params { + + // + // Data members + // + + TensorView view_x0; + TensorView view_x1; + TensorView view_y; + OutputOp output_op; + + + // + // Methods + // + + Params( + TensorView view_x0_ = TensorView(), + TensorView view_x1_ = TensorView(), + TensorView view_y_ = TensorView(), + OutputOp output_op_ = OutputOp(typename OutputOp::Params{}) + ): + view_x0(view_x0_), view_x1(view_x1_), view_y(view_y_), output_op(output_op_) { + } + }; + + Params params; + + CUTLASS_DEVICE + TensorEpilogueForEachFunc(Params const ¶ms): params(params) { + + } + + CUTLASS_DEVICE + void operator()(TensorCoord const &coord) { + Element const & x0 = params.view_x0.at(coord); + Element const & x1 = params.view_x1.at(coord); + Element& y = params.view_y.at(coord); + y = params.output_op(x0, x1); + } +}; + +template < + typename OutputOp, + typename Element, + typename Layout> +void TensorEpilogueForEach( + cutlass::TensorView x0, + cutlass::TensorView x1, + cutlass::TensorView y) { + + using Func = TensorEpilogueForEachFunc; + using Params = typename Func::Params; + + cutlass::reference::device::TensorForEach( + y.extent(), + Params(x0, x1, y) + ); +} + +//////////////////////////////////////////////////////////////////////////////// + +template +struct NonFusedDualGemmRun +{ + + using Gemm0 = Gemm0_; + using Gemm1 = Gemm1_; + using ElementAccumulator = typename Gemm0::ElementAccumulator; + using ElementCompute = typename Gemm0::GemmKernel::Epilogue::OutputOp::ElementCompute; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Bias; + uint64_t seed; + + // + // Methods + // + + NonFusedDualGemmRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), init_Bias(init_Bias_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + 
cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + std::cerr << "Not implemented\n"; + return false; + } + + return true; + } + + + + + /// Executes one test + bool run( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(0), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(0), + bool is_profiling = true, + bool relu = false, + int warm_ups = 1, + int runs = 100) { + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor< + typename Gemm0::ElementA, + typename Gemm0::LayoutA> tensor_A0(problem_size.mk()); + + cutlass::HostTensor< + typename Gemm0::ElementB, + typename Gemm0::LayoutB> tensor_B0(problem_size.kn()); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> tensor_C0(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm0::LayoutC> tensor_Bias0({1, problem_size.n()}); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> tensor_D0(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm0::ElementC, + typename Gemm0::LayoutC> reference_D0(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementB, + typename Gemm1::LayoutB> tensor_B1(problem_size.kn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> tensor_C1(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> tensor_Bias1({1, problem_size.n()}); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> tensor_D1(problem_size.mn()); + + cutlass::HostTensor< + typename Gemm1::ElementC, + typename Gemm1::LayoutC> reference_D1(problem_size.mn()); + + + CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); + CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2018)); + CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); + CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2014)); + CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2016)); + CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); + CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2013)); + + cutlass::reference::host::TensorFill( + tensor_D0.host_view()); + cutlass::reference::host::TensorFill( + tensor_D1.host_view()); + cutlass::reference::host::TensorFill( + reference_D0.host_view()); + cutlass::reference::host::TensorFill( + reference_D1.host_view()); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_C0.sync_device(); + tensor_Bias0.sync_device(); + tensor_D0.sync_device(); + 
reference_D0.sync_device(); + tensor_B1.sync_device(); + tensor_C1.sync_device(); + tensor_Bias1.sync_device(); + tensor_D1.sync_device(); + reference_D1.sync_device(); + + // + // Initialize the GEMM operator + // + + int split_k_slices = Gemm0::kSplitKSerial ? 2 : 1; + typename Gemm0::Arguments arguments_0{ + problem_size, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, + tensor_D0.device_ref(), + {alpha0, beta0}, + split_k_slices + }; + + split_k_slices = Gemm1::kSplitKSerial ? 2 : 1; + typename Gemm1::Arguments arguments_1{ + problem_size, + tensor_A0.device_ref(), + tensor_B1.device_ref(), + {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, + tensor_D1.device_ref(), + {alpha1, beta1}, + split_k_slices + }; + + + Gemm0 gemm_op_0; + Gemm1 gemm_op_1; + + // Allocate workspace memory + cutlass::device_memory::allocation workspace0(gemm_op_0.get_workspace_size(arguments_0)); + cutlass::device_memory::allocation workspace1(gemm_op_1.get_workspace_size(arguments_1)); + + cutlass::Status status = gemm_op_0.initialize(arguments_0, workspace0.get()); + + CUTLASS_CHECK(status); + + status = gemm_op_1.initialize(arguments_1, workspace1.get()); + + CUTLASS_CHECK(status); + + for(int i = 0; i < warm_ups; i++) { + status = gemm_op_0(); + CUTLASS_CHECK(status); + status = gemm_op_1(); + CUTLASS_CHECK(status); + } + + if (is_profiling) { + // + // Profile the GEMM + // + + cudaEvent_t start, stop1, stop2; + cudaEventCreate(&start); + cudaEventCreate(&stop1); + cudaEventCreate(&stop2); + + cudaEventRecord(start); + + for(int i = 0; i < runs; i++) { + status = gemm_op_0(); + + CUTLASS_CHECK(status); + } + cudaEventRecord(stop1); + for(int i = 0; i < runs; i++) { + status = gemm_op_1(); + + CUTLASS_CHECK(status); + } + + cudaEventRecord(stop2); + cudaDeviceSynchronize(); + float gemm0Time, gemm1Time, totalTime; + cudaEventElapsedTime(&gemm0Time, start, stop1); + cudaEventElapsedTime(&gemm1Time, stop1, stop2); + cudaEventElapsedTime(&totalTime, start, stop2); + std::cout << "gemm 0 time " << gemm0Time / (float)runs << " ms\n"; + std::cout << "gemm 1 time " << gemm1Time / (float)runs << " ms\n"; + std::cout << "Non-fusion GEMM only time " << totalTime / (float)runs << " ms\n"; + } + + tensor_D0.sync_host(); + tensor_D1.sync_host(); + + // + // Verify + // + cutlass::reference::device::Gemm< + typename Gemm0::ElementA, typename Gemm0::LayoutA, + typename Gemm0::ElementB, typename Gemm0::LayoutB, + typename Gemm0::ElementC, typename Gemm0::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm0::Operator> + reference_gemm_0; + + cutlass::reference::device::Gemm< + typename Gemm1::ElementA, typename Gemm1::LayoutA, + typename Gemm1::ElementB, typename Gemm1::LayoutB, + typename Gemm1::ElementC, typename Gemm1::LayoutC, ElementCompute, + ElementAccumulator, typename Gemm1::Operator> + reference_gemm_1; + + reference_gemm_0( + problem_size, + alpha0, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + beta0, + {tensor_Bias0.device_data(), typename Gemm0::LayoutC::Stride(0)}, + reference_D0.device_ref() + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D0.device_view()); + } + + reference_gemm_1( + problem_size, + alpha1, + tensor_A0.device_ref(), + tensor_B1.device_ref(), + beta1, + {tensor_Bias1.device_data(), typename Gemm1::LayoutC::Stride(0)}, + reference_D1.device_ref() + ); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D1.device_view()); + } + + // Wait for kernels to 
finish + cudaDeviceSynchronize(); + reference_D0.sync_host(); + reference_D1.sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); + + bool passed0 = cutlass::reference::host::TensorEquals( + reference_D0.host_view(), + tensor_D0.host_view()); + CHECK_TRUE(passed0); + + bool passed1 = cutlass::reference::host::TensorEquals( + reference_D1.host_view(), + tensor_D1.host_view()); + CHECK_TRUE(passed1); + if (!passed0 || !passed1) { + + std::stringstream fname; + + fname << "error_DualGemm_device_nonfused.txt"; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream file(fname.str()); + + file + << "A0 =\n" << tensor_A0.host_view() + << "\nB0 =\n" << tensor_B0.host_view() + << "\nC0 =\n" << tensor_C0.host_view() + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" + << "\nD0 =\n" << tensor_D0.host_view() + << "\nB1 =\n" << tensor_B1.host_view() + << "\nC1 =\n" << tensor_C1.host_view() + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" + << "\n\nReference =\n" << reference_D1.host_view() + << "\nComputed =\n" << tensor_D1.host_view(); + } + return passed0 && passed1; + } +}; + +template +struct DualFusedGemmRun +{ + + using DualGemm = DualGemm_; + using ElementAccumulator = typename DualGemm::ElementAccumulator; + using ElementCompute = typename DualGemm::DualGemmKernel::Epilogue0::OutputOp::ElementCompute; + using EpilogueOutputOp2 = typename DualGemm::EpilogueOutputOp2; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + cutlass::Distribution::Kind init_Scale; + cutlass::Distribution::Kind init_Bias; + uint64_t seed; + + // + // Methods + // + + DualFusedGemmRun( + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Scale_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_Bias_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + init_A(init_A_), init_B(init_B_), init_C(init_C_), + init_Scale(init_Scale_), init_Bias(init_Bias_), seed(seed_) { } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, 2, -2, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else if (dist_kind == cutlass::Distribution::AllZeros) { + cutlass::reference::host::TensorFill(view, Element(0)); + } + else if (dist_kind == cutlass::Distribution::AllOnes) { + cutlass::reference::host::TensorFill(view, Element(1)); + } + else { + std::cerr << "Not implemented\n"; + return false; + } + + return true; + } + + + + + /// Executes one
test + bool run( + cutlass::gemm::GemmCoord problem_size, + ElementCompute alpha0 = ElementCompute(1), + ElementCompute beta0 = ElementCompute(1), + ElementCompute alpha1 = ElementCompute(1), + ElementCompute beta1 = ElementCompute(1), + int batch_count = 1, + bool broadcast_b1 = false, + bool is_profiling = true, + bool relu = false, + int warm_ups = 1, + int runs = 100) { + + // + // Allocate the GEMM workspace + // + + cutlass::HostTensor< + typename DualGemm::ElementA, + typename DualGemm::LayoutA> tensor_A0( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.k()) : + cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.k())); + + cutlass::HostTensor< + typename DualGemm::ElementB, + typename DualGemm::LayoutB0> tensor_B0( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) : + cutlass::MatrixCoord(problem_size.k(), batch_count * problem_size.n())); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> tensor_C0( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutScaleBias> tensor_Bias0({batch_count, problem_size.n()}); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> tensor_D0( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> reference_D0( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); + + cutlass::HostTensor< + typename DualGemm::ElementB, + typename DualGemm::LayoutB1> tensor_B1( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.k(), problem_size.n()) : + cutlass::MatrixCoord(problem_size.k(), batch_count * problem_size.n())); + if (broadcast_b1) { + tensor_B1.resize({problem_size.k(), batch_count}); + } + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> tensor_C1( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutScaleBias> tensor_Bias1({batch_count, problem_size.n()}); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> tensor_D1( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> tensor_D2( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> reference_D1( + cutlass::platform::is_same::value ? 
+ cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); + + cutlass::HostTensor< + typename DualGemm::ElementC, + typename DualGemm::LayoutC> reference_D2( + cutlass::platform::is_same::value ? + cutlass::MatrixCoord(batch_count * problem_size.m(), problem_size.n()) : + cutlass::MatrixCoord(problem_size.m(), batch_count * problem_size.n())); + + CHECK_TRUE(initialize_tensor(tensor_A0.host_view(), init_A, seed + 2019)); + CHECK_TRUE(initialize_tensor(tensor_B0.host_view(), init_B, seed + 2118)); + CHECK_TRUE(initialize_tensor(tensor_C0.host_view(), init_C, seed + 2017)); + CHECK_TRUE(initialize_tensor(tensor_Bias0.host_view(), init_Bias, seed + 2011)); + CHECK_TRUE(initialize_tensor(tensor_B1.host_view(), init_B, seed + 2113)); + CHECK_TRUE(initialize_tensor(tensor_C1.host_view(), init_C, seed + 2015)); + CHECK_TRUE(initialize_tensor(tensor_Bias1.host_view(), init_Bias, seed + 2012)); + + cutlass::reference::host::TensorFill( + tensor_D0.host_view()); + cutlass::reference::host::TensorFill( + tensor_D1.host_view()); + cutlass::reference::host::TensorFill( + tensor_D2.host_view()); + cutlass::reference::host::TensorFill( + reference_D0.host_view()); + cutlass::reference::host::TensorFill( + reference_D1.host_view()); + cutlass::reference::host::TensorFill( + reference_D2.host_view()); + + tensor_A0.sync_device(); + tensor_B0.sync_device(); + tensor_C0.sync_device(); + tensor_Bias0.sync_device(); + tensor_B1.sync_device(); + tensor_C1.sync_device(); + tensor_Bias1.sync_device(); + tensor_D0.sync_device(); + tensor_D1.sync_device(); + tensor_D2.sync_device(); + reference_D0.sync_device(); + reference_D1.sync_device(); + reference_D2.sync_device(); + + // + // Batch strides (irrelevant when batch_count == 1) + // + + int64_t batch_stride_A = problem_size.m() * problem_size.k(); + int64_t batch_stride_B0 = problem_size.k() * problem_size.n(); + int64_t batch_stride_B1 = problem_size.k() * problem_size.n(); + if (broadcast_b1) { + // B1 is a (column) vector + batch_stride_B1 = problem_size.k(); + } + int64_t batch_stride_Bias = problem_size.n(); + int64_t batch_stride_D = problem_size.m() * problem_size.n(); + + // + // Initialize the GEMM operator + // + + int split_k_slices = DualGemm::kSplitKSerial ? 2 : 1; + typename cutlass::TensorRef nullptr_ref{}; + decltype(nullptr_ref) ref_B0, ref_B1; + if (beta0 != ElementCompute(0)) { + ref_B0 = {tensor_Bias0.device_data(), typename DualGemm::LayoutC::Stride(0)}; + } + if (beta1 != ElementCompute(0)) { + ref_B1 = {tensor_Bias1.device_data(), typename DualGemm::LayoutC::Stride(0)}; + } + typename DualGemm::Arguments arguments{ + (batch_count > 1 ? + cutlass::gemm::DualGemmMode::kBatched : + cutlass::gemm::DualGemmMode::kGemm), + problem_size, + tensor_A0.device_ref(), + tensor_B0.device_ref(), + ref_B0, + DualGemm::kStoreD0 ? tensor_D0.device_ref() : nullptr_ref, + (broadcast_b1 ? + typename DualGemm::TensorRefB1(tensor_B1.device_data(), 0) : + tensor_B1.device_ref()), + ref_B1, + DualGemm::kStoreD1 ? 
tensor_D1.device_ref() : nullptr_ref, + tensor_D2.device_ref(), + {alpha0, beta0}, + {alpha1, beta1}, + {}, + split_k_slices, + batch_count, + batch_stride_A, + batch_stride_B0, + batch_stride_B1, + batch_stride_Bias, + batch_stride_D, + }; + + // + // Run the GEMM + // + + DualGemm b2b_gemm_op; + + cutlass::device_memory::allocation workspace(b2b_gemm_op.get_workspace_size(arguments)); + + cutlass::Status status = b2b_gemm_op.can_implement(arguments); + + CUTLASS_CHECK(status); + + status = b2b_gemm_op.initialize(arguments, workspace.get()); + + CUTLASS_CHECK(status); + + for(int i = 0; i < warm_ups; i++) { + status = b2b_gemm_op(); + CUTLASS_CHECK(status); + } + + if (is_profiling) { + // + // Profile the GEMM + // + + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + cudaEventRecord(start); + + for(int i = 0; i < runs; i++) { + status = b2b_gemm_op(); + CUTLASS_CHECK(status); + } + + cudaEventRecord(stop); + cudaDeviceSynchronize(); + float gemmTime; + cudaEventElapsedTime(&gemmTime, start, stop); + std::cout << "Fusion time " << gemmTime / (float)runs << " ms\n"; + } + + tensor_D0.sync_host(); + tensor_D1.sync_host(); + tensor_D2.sync_host(); + + // + // Verify + // + + using GemmUniversal0 = cutlass::gemm::device::GemmUniversal< + typename DualGemm::ElementA, typename DualGemm::LayoutA, + typename DualGemm::ElementB, typename DualGemm::LayoutB0, + typename DualGemm::ElementC, typename DualGemm::LayoutC, + ElementAccumulator + >; + + GemmUniversal0 reference_gemm0; + + typename GemmUniversal0::Arguments args0 { + (batch_count > 1 ? + cutlass::gemm::GemmUniversalMode::kBatched : + cutlass::gemm::GemmUniversalMode::kGemm), + problem_size, + batch_count, + {alpha0, beta0}, + tensor_A0.device_data(), + tensor_B0.device_data(), + tensor_Bias0.device_data(), + reference_D0.device_data(), + batch_stride_A, + batch_stride_B0, + batch_stride_Bias, + batch_stride_D, + tensor_A0.stride(0), + tensor_B0.stride(0), + 0, // zero stride for the bias vector + reference_D0.stride(0), + }; + + status = reference_gemm0.can_implement(args0); + CUTLASS_CHECK(status); + status = reference_gemm0(args0); + CUTLASS_CHECK(status); + + using GemmUniversal1 = cutlass::gemm::device::GemmUniversal< + typename DualGemm::ElementA, typename DualGemm::LayoutA, + typename DualGemm::ElementB, typename DualGemm::LayoutB1, + typename DualGemm::ElementC, typename DualGemm::LayoutC, + ElementAccumulator + >; + + GemmUniversal1 reference_gemm1; + + typename GemmUniversal1::Arguments args1 { + (batch_count > 1 ? + cutlass::gemm::GemmUniversalMode::kBatched : + cutlass::gemm::GemmUniversalMode::kGemm), + problem_size, + batch_count, + {alpha1, beta1}, + tensor_A0.device_data(), + tensor_B1.device_data(), + tensor_Bias1.device_data(), + reference_D1.device_data(), + batch_stride_A, + batch_stride_B1, + batch_stride_Bias, + batch_stride_D, + tensor_A0.stride(0), + (broadcast_b1 ? 
0 : tensor_B1.stride(0)), + 0, // zero stride for the bias vector + reference_D1.stride(0), + }; + + status = reference_gemm1.can_implement(args1); + CUTLASS_CHECK(status); + status = reference_gemm1(args1); + CUTLASS_CHECK(status); + + if(relu) { + cutlass::reference::device::TensorReLu(reference_D0.device_view()); + cutlass::reference::device::TensorReLu(reference_D1.device_view()); + } + + TensorEpilogueForEach(reference_D0.device_view(), reference_D1.device_view(), reference_D2.device_view()); + cudaDeviceSynchronize(); + reference_D0.sync_host(); + reference_D1.sync_host(); + reference_D2.sync_host(); + + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D0.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D1.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D2.host_view()), 0); + CHECK_GT(cutlass::reference::host::TensorNorm(reference_D2.host_view()), 0); + + bool passed_out0 = true; + if (DualGemm::kStoreD0) { + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D0.host_view()), 0); + passed_out0 = cutlass::reference::host::TensorEquals( + reference_D0.host_view(), + tensor_D0.host_view()); + } + CHECK_TRUE(passed_out0); + + bool passed_out1 = true; + if (DualGemm::kStoreD1) { + CHECK_GT(cutlass::reference::host::TensorNorm(tensor_D1.host_view()), 0); + passed_out1 = cutlass::reference::host::TensorEquals( + reference_D1.host_view(), + tensor_D1.host_view()); + } + CHECK_TRUE(passed_out1); + + bool passed_out2 = cutlass::reference::host::TensorEquals( + reference_D2.host_view(), + tensor_D2.host_view()); + CHECK_TRUE(passed_out2); + + bool passed = passed_out0 && passed_out1 && passed_out2; + if (!passed) + { + std::stringstream fname; + + fname << "error_DualGemm_device_fused.txt"; + std::cerr << "Dumping results in " << fname.str() << "\n"; + + std::ofstream file(fname.str()); + + file + << "A0 =\n" << tensor_A0.host_view() + << "\nB0 =\n" << tensor_B0.host_view() + << "\nC0 =\n" << tensor_C0.host_view() + << "\nBias0:\n" << tensor_Bias0.host_view() << "\n" + << "\nB1 =\n" << tensor_B1.host_view() + << "\nC1 =\n" << tensor_C1.host_view() + << "\nBias1:\n" << tensor_Bias1.host_view() << "\n" + << "\n\nReference0 =\n" << reference_D0.host_view() + << "\nComputed0 =\n" << tensor_D0.host_view() + << "\n\nReference1 =\n" << reference_D1.host_view() + << "\nComputed1 =\n" << tensor_D1.host_view() + << "\n\nReference2 =\n" << reference_D2.host_view() + << "\nComputed2 =\n" << tensor_D2.host_view(); + } + //std::cout << "A0 " << tensor_A0.host_view() << std::endl; + // std::cout << "reference_D0 " << reference_D0.host_view() << std::endl; + // std::cout << "reference_D1 " << reference_D1.host_view() << std::endl; + // std::cout << "reference_D2 " << reference_D2.host_view() << std::endl; + //std::cout << "reference_D0 " << reference_D0.host_view() << std::endl; + return passed; + } + +}; + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/45_dual_gemm/kernel/dual_gemm.h b/examples/45_dual_gemm/kernel/dual_gemm.h new file mode 100644 index 0000000000..417f6ff25c --- /dev/null +++ b/examples/45_dual_gemm/kernel/dual_gemm.h @@ -0,0 +1,545 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a pipelined GEMM kernel. Does not compute batching or support split-K. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/semaphore.h" + +#include "../threadblock/dual_mma_multistage.h" +#include "../threadblock/dual_epilogue.h" +#include "../dual_gemm_common.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename DualMma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue0_, ///! Epilogue + typename Epilogue1_, ///! Epilogue + typename OutputOp2_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + bool SplitKSerial, ///! If true, code supporting split-K via serial reduction is enabled. 
+ bool StoreD0, + bool StoreD1 +> +struct DualGemm { + + using DualMma = DualMma_; + + using Epilogue0 = Epilogue0_; + using Epilogue1 = Epilogue1_; + using OutputOp0 = typename Epilogue0::OutputOp; + using OutputOp1 = typename Epilogue1::OutputOp; + using OutputOp2 = OutputOp2_; + using ThreadblockSwizzle = ThreadblockSwizzle_; + static constexpr bool kStoreD0 = StoreD0; + static constexpr bool kStoreD1 = StoreD1; + + using DualEpilogue = cutlass::epilogue::threadblock::DualEpilogue< + typename Epilogue0::Shape, + typename Epilogue0::WarpMmaOperator, + Epilogue0::kPartitionsK, + typename Epilogue0::OutputTileIterator, + typename Epilogue0::AccumulatorFragmentIterator, + typename Epilogue0::WarpTileIterator, + typename Epilogue0::SharedLoadIterator, + OutputOp0, + OutputOp1, + OutputOp2, + typename Epilogue0::Padding, + kStoreD0, + kStoreD1, + Epilogue0::kFragmentsPerIteration, + true // IterationsUnroll + >; + + using ElementA = typename DualMma::IteratorA::Element; + using ElementB = typename DualMma::IteratorB0::Element; + using ElementC = typename DualEpilogue::OutputTileIterator::Element; + + static bool const kSplitKSerial = SplitKSerial; + static_assert(!kSplitKSerial || (kStoreD0 && kStoreD1), + "Split-K serial requires buffers for D0/D1 for reduction"); + + /// Warp count (concept: GemmShape) + using WarpCount0 = typename DualMma::WarpCount; + static int const kThreadCount = 32 * WarpCount0::kCount; + + /// Parameters structure + struct Params { + DualGemmMode mode; + cutlass::gemm::GemmCoord problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + int swizzle_log_tile; + + // Mma0 + typename DualMma::IteratorA::Params params_A0; + typename DualMma::IteratorA::TensorRef ref_A0; + typename DualMma::IteratorB0::Params params_B0; + typename DualMma::IteratorB0::TensorRef ref_B0; + typename Epilogue0::OutputTileIterator::Params params_C0; + typename Epilogue0::OutputTileIterator::TensorRef ref_C0; + typename Epilogue0::OutputTileIterator::Params params_D0; + typename Epilogue0::OutputTileIterator::TensorRef ref_D0; + typename OutputOp0::Params output_op_0; + + // Mma1 + typename DualMma::IteratorB1::Params params_B1; + typename DualMma::IteratorB1::TensorRef ref_B1; + typename Epilogue1::OutputTileIterator::Params params_C1; + typename Epilogue1::OutputTileIterator::TensorRef ref_C1; + typename Epilogue1::OutputTileIterator::Params params_D1; + typename Epilogue1::OutputTileIterator::TensorRef ref_D1; + typename OutputOp1::Params output_op_1; + + typename Epilogue1::OutputTileIterator::Params params_D2; + typename Epilogue1::OutputTileIterator::TensorRef ref_D2; + typename OutputOp2::Params output_op_2; + + int *semaphore; + int gemm_k_size; + + int64_t batch_stride_A; + int64_t batch_stride_B0; + int64_t batch_stride_B1; + int64_t batch_stride_C; + int64_t batch_stride_D; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): swizzle_log_tile(0), semaphore(0), gemm_k_size(0) { } + + CUTLASS_HOST_DEVICE + Params( + DualGemmMode mode, + cutlass::gemm::GemmCoord const & problem_size, + cutlass::gemm::GemmCoord const & grid_tiled_shape, + // Mma0: D0 = A @ B0 + C0 + typename DualMma::IteratorA::TensorRef ref_A0, + typename DualMma::IteratorB0::TensorRef ref_B0, + typename Epilogue0::OutputTileIterator::TensorRef ref_C0, + typename Epilogue0::OutputTileIterator::TensorRef ref_D0, + // Mma1: D1 = A @ B1 + C1 + typename DualMma::IteratorB1::TensorRef ref_B1, + typename Epilogue1::OutputTileIterator::TensorRef ref_C1, + typename Epilogue1::OutputTileIterator::TensorRef ref_D1, + + 
typename Epilogue1::OutputTileIterator::TensorRef ref_D2, + typename OutputOp0::Params output_op_0 = typename OutputOp0::Params(), + typename OutputOp1::Params output_op_1 = typename OutputOp1::Params(), + typename OutputOp2::Params output_op_2 = typename OutputOp2::Params(), + int *workspace = nullptr, + int64_t batch_stride_A = 1, + int64_t batch_stride_B0 = 1, + int64_t batch_stride_B1 = 1, + int64_t batch_stride_C = 1, + int64_t batch_stride_D = 1 + ): + mode(mode), + problem_size(problem_size), + grid_tiled_shape(grid_tiled_shape), + swizzle_log_tile(ThreadblockSwizzle().get_log_tile(grid_tiled_shape)), + // Mma0 + params_A0(ref_A0.layout()), + ref_A0(ref_A0), + params_B0(ref_B0.layout()), + ref_B0(ref_B0), + params_C0(ref_C0.layout()), + ref_C0(ref_C0), + params_D0(ref_D0.layout()), + ref_D0(ref_D0), + // Mma1 + params_B1(ref_B1.layout()), + ref_B1(ref_B1), + params_C1(ref_C1.layout()), + ref_C1(ref_C1), + params_D1(ref_D1.layout()), + ref_D1(ref_D1), + params_D2(ref_D2.layout()), + ref_D2(ref_D2), + output_op_0(output_op_0), + output_op_1(output_op_1), + output_op_2(output_op_2), + batch_stride_A(batch_stride_A), + batch_stride_B0(batch_stride_B0), + batch_stride_B1(batch_stride_B1), + batch_stride_C(batch_stride_C), + batch_stride_D(batch_stride_D) { + + int total_gemm_k_iterations = (problem_size.k() + DualMma::Shape::kK - 1) / DualMma::Shape::kK; + int gemm_k_iterations = (total_gemm_k_iterations + grid_tiled_shape.k() - 1) / grid_tiled_shape.k(); + gemm_k_size = gemm_k_iterations * DualMma::Shape::kK; + + semaphore = workspace; + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename DualMma::SharedStorage main_loop; + typename DualEpilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + DualGemm() { } + + /// Determines whether kernel satisfies alignment + static Status can_implement( + cutlass::gemm::GemmCoord const & problem_size, + typename DualMma::IteratorA::TensorRef ref_A0, + typename DualMma::IteratorB0::TensorRef ref_B0, + typename Epilogue0::OutputTileIterator::TensorRef ref_C0, + typename Epilogue0::OutputTileIterator::TensorRef ref_D0, + typename DualMma::IteratorB1::TensorRef ref_B1, + typename Epilogue1::OutputTileIterator::TensorRef ref_C1, + typename Epilogue1::OutputTileIterator::TensorRef ref_D1, + typename Epilogue1::OutputTileIterator::TensorRef ref_D2) { + + static int const kAlignmentA = DualMma::IteratorA::AccessType::kElements; + static int const kAlignmentB = DualMma::IteratorB0::AccessType::kElements; + static int const kAlignmentC = Epilogue0::OutputTileIterator::kElementsPerAccess; + + if (!TensorRef_aligned(ref_A0, kAlignmentA)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_B0, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_C0, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_D0, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_B1, kAlignmentB)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_C1, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_D1, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + if (!TensorRef_aligned(ref_D2, kAlignmentC)) { + return Status::kErrorMisalignedOperand; + } + + return Status::kSuccess; + } + + /// Executes one GEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + // 
Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_offset.m() || + params.grid_tiled_shape.n() <= threadblock_tile_offset.n()) { + + return; + } + + int offset_k = 0; + int problem_size_k = params.problem_size.k(); + + ElementA *ptr_A0 = static_cast(params.ref_A0.data()); + ElementB *ptr_B0 = static_cast(params.ref_B0.data()); + ElementB *ptr_B1 = static_cast(params.ref_B1.data()); + + // + // Fetch pointers based on mode. + // + if (params.mode == DualGemmMode::kGemm) { + if (threadblock_tile_offset.k() + 1 < params.grid_tiled_shape.k()) { + problem_size_k = (threadblock_tile_offset.k() + 1) * params.gemm_k_size; + } + + offset_k = threadblock_tile_offset.k() * params.gemm_k_size; + } + else if (params.mode == DualGemmMode::kBatched) { + ptr_A0 += threadblock_tile_offset.k() * params.batch_stride_A; + ptr_B0 += threadblock_tile_offset.k() * params.batch_stride_B0; + ptr_B1 += threadblock_tile_offset.k() * params.batch_stride_B1; + } + + // Compute initial location in logical coordinates + cutlass::MatrixCoord tb_offset_A0{ + threadblock_tile_offset.m() * DualMma::Shape::kM, + offset_k, + }; + + cutlass::MatrixCoord tb_offset_B0{ + offset_k, + threadblock_tile_offset.n() * DualMma::Shape::kN + }; + + cutlass::MatrixCoord tb_offset_B1{ + offset_k, + threadblock_tile_offset.n() * DualMma::Shape::kN + }; + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename DualMma::IteratorA iterator_A0( + params.params_A0, + ptr_A0, + {params.problem_size.m(), problem_size_k}, + thread_idx, + tb_offset_A0); + + typename DualMma::IteratorB0 iterator_B0( + params.params_B0, + ptr_B0, + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B0); + + typename DualMma::IteratorB1 iterator_B1( + params.params_B1, + ptr_B1, + {problem_size_k, params.problem_size.n()}, + thread_idx, + tb_offset_B1); + + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. 
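  // (Each lane passes the same threadIdx.x / 32 value to __shfl_sync and reads
  //  lane 0's copy, so every lane receives an identical result and the compiler
  //  can treat warp_idx as uniform across the warp.)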
+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + + // Construct thread-scoped matrix multiply + typename DualMma::FragmentC accum0; + typename DualMma::FragmentC accum1; + accum0.clear(); + accum1.clear(); + + // Compute threadblock-scoped matrix multiply-add + int gemm_k_iterations = (problem_size_k - offset_k + DualMma::Shape::kK - 1) / DualMma::Shape::kK; + + DualMma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + if (!kSplitKSerial || gemm_k_iterations > 0) { + // Compute threadblock-scoped matrix multiply-add + mma(gemm_k_iterations, + accum0, accum1, + iterator_A0, iterator_B0, iterator_B1, + accum0, accum1); + } + + // + // Epilogue + // + + OutputOp0 output_op_0(params.output_op_0); + OutputOp1 output_op_1(params.output_op_1); + OutputOp2 output_op_2(params.output_op_2); + + // + // Masked tile iterators constructed from members + // + + threadblock_tile_offset = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + //assume identity swizzle + MatrixCoord threadblock_offset( + threadblock_tile_offset.m() * DualMma::Shape::kM, + threadblock_tile_offset.n() * DualMma::Shape::kN + ); + + int block_idx = threadblock_tile_offset.m() + threadblock_tile_offset.n() * params.grid_tiled_shape.m(); + + ElementC *ptr_C0 = static_cast(params.ref_C0.data()); + ElementC *ptr_C1 = static_cast(params.ref_C1.data()); + ElementC *ptr_D0 = static_cast(params.ref_D0.data()); + ElementC *ptr_D1 = static_cast(params.ref_D1.data()); + ElementC *ptr_D2 = static_cast(params.ref_D2.data()); + + // Construct the semaphore. + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + if (params.mode == DualGemmMode::kGemm) { + // If performing a reduction via split-K, fetch the initial synchronization + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op_0.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + output_op_1.set_k_partition(threadblock_tile_offset.k(), params.grid_tiled_shape.k()); + } + } + else if (params.mode == DualGemmMode::kBatched) { + ptr_C0 += threadblock_tile_offset.k() * params.batch_stride_C; + ptr_C1 += threadblock_tile_offset.k() * params.batch_stride_C; + ptr_D0 += threadblock_tile_offset.k() * params.batch_stride_D; + ptr_D1 += threadblock_tile_offset.k() * params.batch_stride_D; + ptr_D2 += threadblock_tile_offset.k() * params.batch_stride_D; + } + + // Tile iterator loading from source tensor. + typename Epilogue0::OutputTileIterator iterator_C0( + params.params_C0, + ptr_C0, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + typename Epilogue1::OutputTileIterator iterator_C1( + params.params_C1, + ptr_C1, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + // Tile iterator writing to destination tensor. 
+ typename Epilogue0::OutputTileIterator iterator_D0( + params.params_D0, + ptr_D0, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + typename Epilogue1::OutputTileIterator iterator_D1( + params.params_D1, + ptr_D1, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + typename Epilogue1::OutputTileIterator iterator_D2( + params.params_D2, + ptr_D2, + params.problem_size.mn(), + thread_idx, + threadblock_offset + ); + + DualEpilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_offset.k()) { + iterator_C0 = iterator_D0; + iterator_C1 = iterator_D1; + } + + semaphore.wait(threadblock_tile_offset.k()); + + __threadfence(); + } + + // Execute the epilogue operator to update the destination tensor. + typename Epilogue0::OutputTileIterator source_iters[] = { + iterator_C0, iterator_C1 + }; + const bool writeToD2 = (!kSplitKSerial || params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1); + epilogue( + output_op_0, output_op_1, output_op_2, + iterator_D0, iterator_D1, iterator_D2, + accum0, accum1, + source_iters, + writeToD2 + ); + + // + // Release the semaphore + // + + if (kSplitKSerial && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_offset.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_offset.k() + 1; + } + + __threadfence(); + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace gemm +} // namespace cutlass + diff --git a/examples/45_dual_gemm/test_run.h b/examples/45_dual_gemm/test_run.h new file mode 100644 index 0000000000..4a58a3a16c --- /dev/null +++ b/examples/45_dual_gemm/test_run.h @@ -0,0 +1,95 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + + +#include + +// Run tests on GPUs + +int testRun(int arch, std::vector & test_funcs, const std::string & test_name) { + + bool supported = false; + + int arch_major = arch / 10; + int arch_minor = arch - arch / 10 * 10; + + if(arch_major >= 8) { + // Ampere Tensor Core operations exposed with mma.sync are first available in CUDA 11.0. + // + // CUTLASS must be compiled with CUDA 11 Toolkit to run Conv2dFprop examples. + if (__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 0)) { + supported = true; + } + } + else if(arch_major >= 7) { + // Turing Tensor Core operations exposed with mma.sync are first available in CUDA 10.2. + // + // CUTLASS must be compiled with CUDA 10.2 Toolkit to run these examples. + if (__CUDACC_VER_MAJOR__ > 10 || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) { + supported = true; + } + } + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (props.major < arch_major || (props.major == arch_major && props.minor < arch_minor) ) { + supported = false; + } + + if (!supported) { + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + std::cout << "This example isn't supported on current architecture" << std::endl; + return 0; + } + + bool pass = true; + + std::cout << "Device: " << props.name << std::endl; + std::cout << "Arch: SM" << arch << std::endl; + std::cout << "Test: " << test_name << std::endl; + for(auto func : test_funcs) { + pass &= func(); + } + + + if(pass) + return 0; + else + return -1; + +} + diff --git a/examples/45_dual_gemm/thread/left_silu_and_mul.h b/examples/45_dual_gemm/thread/left_silu_and_mul.h new file mode 100644 index 0000000000..47043267f5 --- /dev/null +++ b/examples/45_dual_gemm/thread/left_silu_and_mul.h @@ -0,0 +1,150 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing linear combination operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/functional.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/epilogue/thread/linear_combination_params.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies a linear combination operator to an array of elements. +/// +/// D = alpha * accumulator + beta * source + uniform +/// +template < + typename ElementOutput_, ///< Data type used to load and store tensors + int Count, ///< Number of elements computed per operation. 
+ ///< Usually it is 128/sizeof_bits, + ///< but we use 64 or 32 sometimes when there are not enough data to store + typename ElementAccumulator_ = ElementOutput_, ///< Accumulator data type + typename ElementCompute_ = ElementOutput_, ///< Data type used to compute linear combination + FloatRoundStyle Round = FloatRoundStyle::round_to_nearest +> +class LeftSiLUAndMul { +public: + + using ElementOutput = ElementOutput_; + using ElementAccumulator = ElementAccumulator_; + using ElementCompute = ElementCompute_; + + static int const kCount = Count; + using FragmentOutput = Array; + using FragmentAccumulator = Array; + using ComputeFragment = Array; + + static FloatRoundStyle const kRound = Round; + + struct Params{}; + +private: + + // + // Data members + // + + ElementCompute alpha_; + ElementCompute beta_; + +public: + + /// Constructs the function object, possibly loading from pointers in host memory + CUTLASS_HOST_DEVICE + LeftSiLUAndMul(Params const &/*params*/) {} + + /// Returns true if source is needed + CUTLASS_HOST_DEVICE + bool is_source_needed() const { + return true; + } + + /// Functionally required for serial reduction in the epilogue + CUTLASS_HOST_DEVICE + void set_k_partition(int k_partition, int k_partition_count) { + assert(false); + } + + /// Computes linear scaling: D = alpha * accumulator + beta * source + CUTLASS_HOST_DEVICE + FragmentOutput operator()( + FragmentAccumulator const &lhs, + FragmentAccumulator const &rhs) const { + + // Convert source to interal compute numeric type + NumericArrayConverter accumulator_to_compute; + + // Convert to destination numeric type + NumericArrayConverter compute_to_output; + + ComputeFragment converted_lhs = accumulator_to_compute(lhs); + ComputeFragment converted_rhs = accumulator_to_compute(rhs); + + cutlass::epilogue::thread::SiLu silu; + cutlass::multiplies mul; + auto silu_lhs = silu(converted_lhs); + return compute_to_output(mul(silu_lhs, converted_rhs)); + } + + CUTLASS_HOST_DEVICE + ElementOutput operator()( + ElementAccumulator const& lhs, + ElementAccumulator const& rhs + ) const { + ElementCompute convert_lhs(lhs); + ElementCompute convert_rhs(rhs); + cutlass::epilogue::thread::SiLu silu; + cutlass::multiplies mul; + auto silu_lhs = silu(convert_lhs); + return ElementOutput(mul(silu_lhs, convert_rhs)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace thread +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/45_dual_gemm/threadblock/dual_epilogue.h b/examples/45_dual_gemm/threadblock/dual_epilogue.h new file mode 100644 index 0000000000..3ef1c6d33c --- /dev/null +++ b/examples/45_dual_gemm/threadblock/dual_epilogue.h @@ -0,0 +1,426 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Epilogue for threadblock scoped GEMMs using Tensor Ops. + + The epilogue rearranges the result of a matrix product through shared memory to match canonical + tensor layouts in global memory. Epilogues support conversion and reduction operations. + +*/ + +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" +#include "cutlass/array.h" +#include "cutlass/layout/vector.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/functional.h" + +#include "cutlass/gemm/gemm.h" + +#include "cutlass/transform/pitch_linear_thread_map.h" +#include "cutlass/transform/threadblock/regular_tile_iterator.h" + +#include "cutlass/epilogue/threadblock/epilogue_base.h" +#include "cutlass/epilogue/threadblock/predicated_tile_iterator.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Epilogue operator +template < + typename Shape_, ///< Shape of threadblock tile (concept: GemmShape) + typename WarpMmaOperator_, ///< Warp-level MMA operator (concept: gemm::warp::MmaTensorOp) + int PartitionsK, ///< Number of partitions of the K dimension + typename OutputTileIterator_, ///< Tile iterator reading and writing output tensors + typename AccumulatorFragmentIterator_, ///< Fragment iterator selecting accumulators + typename WarpTileIterator_, ///< Warp-scoped tile iterator writing accumulators to SMEM + typename SharedLoadIterator_, ///< Threadblock-scoped tile iterator loading from SMEM + ///< Output operator + typename OutputOp0_, + typename OutputOp1_, + typename OutputOp2_, + typename Padding_, ///< Padding added to SMEM allocation to avoid bank conflicts (concept: MatrixShape) + bool StoreD0 = true, + bool StoreD1 = true, + int FragmentsPerPartition = 1, ///< Used to coarsten the epilogue granularity + int IterationsUnroll = ///< Used to reduce binary size when epilogue op is large + (!IsEpilogueFunctorHeavy::value) +> +class DualEpilogue { + +public: + + 
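  // In outline, each pair of accumulator tiles flows through this epilogue roughly as
  // follows (a sketch of the data path, not the exact fragment-level code):
  //
  //   frag0 = OutputOp0(accum0, source0);   // e.g. alpha0 * accum0 + beta0 * C0
  //   frag1 = OutputOp1(accum1, source1);   // e.g. alpha1 * accum1 + beta1 * C1
  //   frag2 = OutputOp2(frag0, frag1);      // e.g. SiLU(frag0) * frag1
  //
  //   D0 <- frag0 (if kStoreD0),  D1 <- frag1 (if kStoreD1),  D2 <- frag2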
using Base = EpilogueBase< + Shape_, + typename WarpMmaOperator_::Shape, + PartitionsK, + AccumulatorFragmentIterator_, + WarpTileIterator_, + Padding_, + FragmentsPerPartition>; + + using Shape = Shape_; + using WarpMmaOperator = WarpMmaOperator_; + static int const kPartitionsK = PartitionsK; + static bool constexpr kStoreD0 = StoreD0; + static bool constexpr kStoreD1 = StoreD1; + using OutputTileIterator = OutputTileIterator_; + using AccumulatorFragmentIterator = AccumulatorFragmentIterator_; + using WarpTileIterator = WarpTileIterator_; + using SharedLoadIterator = SharedLoadIterator_; + using OutputOp0 = OutputOp0_; + using OutputOp1 = OutputOp1_; + using OutputOp2 = OutputOp2_; + using Padding = Padding_; + + using Layout = layout::RowMajor; + using LongIndex = typename Layout::LongIndex; + + /// The complete warp-level accumulator tile + using AccumulatorTile = typename Base::AccumulatorTile; + + /// Accumulator element + using ElementAccumulator = typename WarpTileIterator::Element; + + /// Output element + using ElementOutput = typename OutputTileIterator::Element; + + /// Output access size + static int const kElementsPerAccess = OutputTileIterator::kElementsPerAccess; + + /// Tensor reference to destination tensor + using TensorRef = typename OutputTileIterator::TensorRef; + + /// Tensor reference to sync tensor + using SyncTensorRef = typename cutlass::TensorRef; + + /// Const tensor reference to source tensor + using ConstTensorRef = typename OutputTileIterator::ConstTensorRef; + + /// Array type used to output + using OutputAccessType = Array< + typename OutputTileIterator::Element, OutputTileIterator::kElementsPerAccess>; + + /// Array type used by output functor + using AccumulatorAccessType = Array; + + /// Number of warps + using WarpCount = typename Base::WarpCount; + + struct SharedStorage { + using Element = typename WarpTileIterator::Element; + + /// Tensor reference to shared memory allocation + using TensorRef = typename WarpTileIterator::TensorRef; + + /// Logical shape of the shared memory tile written to by all warps. + using Shape = typename Base::Shape; + + /// Shape of the shared memory allocation for the epilogue + using StorageShape = typename Base::SharedStorage::StorageShape; + + // + // Data members + // + + AlignedBuffer storage[2]; + + // + // Methods + // + + /// Returns a tensor reference to the shared memory buffer + CUTLASS_DEVICE + TensorRef reference(int i) { + return TensorRef( + storage[i].data(), + Layout::packed({StorageShape::kRow, StorageShape::kColumn})); + } + }; + + static int constexpr kSmemTiles = Base::kFragmentsPerIteration > 1 ? 
Base::kFragmentsPerIteration : kPartitionsK; + static int constexpr kSmemPointerOffset = SharedStorage::StorageShape::kCount / kSmemTiles; + +public: + + static_assert(SharedLoadIterator::Fragment::kElements == OutputTileIterator::Fragment::kElements, + "Mismatch between shared load iterator and output tile iterator."); + + static_assert(OutputTileIterator::kElementsPerAccess, "OutputTileIterator::kElementsPerAccess must not be zero."); + + static_assert(!(OutputTileIterator::Fragment::kElements % OutputTileIterator::kElementsPerAccess), + "Divisibility"); + +private: + + /// Loads fragment from shared memory aligned with output tensor + SharedLoadIterator shared_load_iterator0_; + SharedLoadIterator shared_load_iterator1_; + + /// Stores a warp's fragment of accumulators to SMEM + WarpTileIterator warp_tile_iterator0_; + WarpTileIterator warp_tile_iterator1_; + +public: + + /// Constructor + CUTLASS_DEVICE + DualEpilogue( + SharedStorage &shared_storage, ///< Shared storage object + int thread_idx, ///< ID of a thread within the threadblock + int warp_idx, ///< ID of warp within threadblock + int lane_idx ///< Id of thread within warp + ): + shared_load_iterator0_(shared_storage.reference(0), thread_idx), + shared_load_iterator1_(shared_storage.reference(1), thread_idx), + warp_tile_iterator0_(shared_storage.reference(0), lane_idx), + warp_tile_iterator1_(shared_storage.reference(1), lane_idx) + { + int warp_k = warp_idx / (WarpCount::kM * WarpCount::kN); + int warp_mn = warp_idx % (WarpCount::kM * WarpCount::kN); + int warp_m = warp_mn % WarpCount::kM; + int warp_n = warp_mn / WarpCount::kM; + + MatrixCoord warp_offset{warp_k * WarpCount::kM + warp_m, warp_n}; + + warp_tile_iterator0_.add_tile_offset(warp_offset); + warp_tile_iterator1_.add_tile_offset(warp_offset); + } + + /// Streams the result to global memory + CUTLASS_DEVICE + void operator()( + OutputOp0 const &output_op0, + OutputOp1 const &output_op1, + OutputOp2 const &output_op2, + OutputTileIterator dest0, + OutputTileIterator dest1, + OutputTileIterator dest2, + AccumulatorTile const &accumulator0, + AccumulatorTile const &accumulator1, + OutputTileIterator source_iterator[2], + bool writeToD2 // true if it's the final split-k + ) { + // TODO: Implement when no source is needed + + typename OutputTileIterator::Fragment source_fragment[2]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + source_fragment[i].clear(); + } + + // + // Iterator over warp-level accumulator fragment + // + + AccumulatorFragmentIterator accum_fragment_iterator[2] = {accumulator0, accumulator1}; + + // + // Iterate over accumulator tile + // + + #pragma unroll(IterationsUnroll ? 
OutputTileIterator::kIterations : 1) + for (int iter = 0; iter < OutputTileIterator::kIterations; ++iter) { + + // + // Load the source + // + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + source_iterator[i].load(source_fragment[i]); + ++source_iterator[i]; + } + + // + // Convert and store fragment + // + + __syncthreads(); + + acc2smem_source_needed>::push( + iter, accum_fragment_iterator[0], this->warp_tile_iterator0_); + acc2smem_source_needed>::push( + iter, accum_fragment_iterator[1], this->warp_tile_iterator1_); + + __syncthreads(); + + // + // Load fragments from shared memory + // + + typename SharedLoadIterator::Fragment aligned_accum_fragment0[kPartitionsK]; + typename SharedLoadIterator::Fragment aligned_accum_fragment1[kPartitionsK]; + + shared_load_iterator0_.load(aligned_accum_fragment0[0]); + shared_load_iterator1_.load(aligned_accum_fragment1[0]); + + // If the number of k-slices is > 1 - perform a reduction amongst the k-slices + if (kPartitionsK > 1) { + + plus add_fragments; + + CUTLASS_PRAGMA_UNROLL + for ( int i = 1; i < kPartitionsK; ++i) { + shared_load_iterator0_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator1_.add_pointer_offset(kSmemPointerOffset); + shared_load_iterator0_.load(aligned_accum_fragment0[i]); + shared_load_iterator1_.load(aligned_accum_fragment1[i]); + aligned_accum_fragment0[0] = add_fragments(aligned_accum_fragment0[0], aligned_accum_fragment0[i]); + aligned_accum_fragment1[0] = add_fragments(aligned_accum_fragment1[0], aligned_accum_fragment1[i]); + } + + shared_load_iterator0_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); + shared_load_iterator1_.add_pointer_offset((1 - kPartitionsK) * kSmemPointerOffset); + } + + // + // Compute the output result + // + + typename OutputTileIterator::Fragment output_fragment[3]; + + apply_output_operator_(output_fragment, + output_op0, output_op1, output_op2, + aligned_accum_fragment0[0], aligned_accum_fragment1[0], + source_fragment); + + + // + // Store the final result + // + + if (kStoreD0) { + dest0.store(output_fragment[0]); + ++dest0; + } + if (kStoreD1) { + dest1.store(output_fragment[1]); + ++dest1; + } + if (writeToD2) { + dest2.store(output_fragment[2]); + ++dest2; + } + } + } + +private: + + static_assert(kPartitionsK == 1 || Base::kFragmentsPerIteration == 1, "One of these must be exactly 1."); + + template + struct acc2smem_source_needed; + + template + struct acc2smem_source_needed> { + template + CUTLASS_DEVICE + static void helper(AccumulatorFragmentIterator accum_fragment_iterator, + WarpTileIterator &warp_tile_iterator) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Advance; i++) { + ++accum_fragment_iterator; + } + + typename AccumulatorFragmentIterator::Fragment accum_fragment; + accum_fragment_iterator.load(accum_fragment); + warp_tile_iterator.store(accum_fragment); + } + + CUTLASS_DEVICE + static void push(size_t pos, + AccumulatorFragmentIterator const &iterator_begin, + WarpTileIterator &warp_tile_iterator) { + int dummy[] = {(pos == Seq) && (helper(iterator_begin, warp_tile_iterator), 0)...}; + } + }; + + /// Helper to invoke the output functor over each vector of output + CUTLASS_DEVICE + void apply_output_operator_( + typename OutputTileIterator::Fragment (&output_fragment)[3], + OutputOp0 const &output_op0, + OutputOp1 const &output_op1, + OutputOp2 const &output_op2, + typename SharedLoadIterator::Fragment const& aligned_accum_fragment0, + typename SharedLoadIterator::Fragment const& aligned_accum_fragment1, + typename 
OutputTileIterator::Fragment const (&source_fragment)[2]) { + + OutputAccessType* output_frag_ptr[3] = { + reinterpret_cast(&output_fragment[0]), + reinterpret_cast(&output_fragment[1]), + reinterpret_cast(&output_fragment[2]) + }; + + AccumulatorAccessType const *compute_frag_ptr[2] = { + reinterpret_cast(&aligned_accum_fragment0), + reinterpret_cast(&aligned_accum_fragment1) + }; + + OutputAccessType const *source_frag_ptr[2] = { + reinterpret_cast(&source_fragment[0]), + reinterpret_cast(&source_fragment[1]) + }; + + int const kOutputOpIterations = + OutputTileIterator::Fragment::kElements / OutputTileIterator::kElementsPerAccess; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kOutputOpIterations; ++i) { + + // Call the output operators + output_frag_ptr[0][i] = output_op0(compute_frag_ptr[0][i], source_frag_ptr[0][i]); + output_frag_ptr[1][i] = output_op1(compute_frag_ptr[1][i], source_frag_ptr[1][i]); + output_frag_ptr[2][i] = output_op2(output_frag_ptr[0][i], output_frag_ptr[1][i]); + } + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace epilogue +} // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////// diff --git a/examples/45_dual_gemm/threadblock/dual_mma_base.h b/examples/45_dual_gemm/threadblock/dual_mma_base.h new file mode 100644 index 0000000000..3a25da9c2c --- /dev/null +++ b/examples/45_dual_gemm/threadblock/dual_mma_base.h @@ -0,0 +1,232 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. 
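    Two B operands (B0 and B1) are staged in shared memory alongside a single A
    operand so that one pass over A can feed both halves of the dual GEMM.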
+*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/threadblock/mma_base.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy0_, + /// B1-specific version of the policy (concept: MmaPolicy) + typename Policy1_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class DualMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy0 = Policy0_; + using Policy1 = Policy1_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator0 = typename Policy0::Operator; + using Operator1 = typename Policy1::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy0::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = GemmShape; + + /// Number of warp-level GEMM oeprations + static int const kWarpGemmIterations = + (WarpGemm::kK / Operator0::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB0 = TensorRef; + using TensorRefB1 = TensorRef; + + static_assert(kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + static_assert((kWarpGemmIterations % 2) == 0, + "Inner loop iteration must be an even number."); + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape; + + /// Shape of the B matrix operand in shared memory + using ShapeB0 = + MatrixShape; + using ShapeB1 = + MatrixShape; + + public: + // + // Data members + // + + /// Buffer for A operand + AlignedBuffer operand_A; + + /// Buffer for B operand + AlignedBuffer operand_B0; + AlignedBuffer operand_B1; + + public: + + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator0::LayoutA LayoutA() { + return Operator0::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator0::LayoutB LayoutB0() { + return Operator0::LayoutB::packed({ShapeB0::kRow, ShapeB0::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator1::LayoutB LayoutB1() { + return Operator1::LayoutB::packed({ShapeB1::kRow, ShapeB1::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { + return TensorRefA{operand_A.data(), LayoutA()}; + } + + /// Returns a 
TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB0 operand_B0_ref() { + return TensorRefB0{operand_B0.data(), LayoutB0()}; + } + CUTLASS_HOST_DEVICE + TensorRefB1 operand_B1_ref() { + return TensorRefB1{operand_B1.data(), LayoutB1()}; + } + }; + + protected: + + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator0::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator0::IteratorB warp_tile_iterator_B0_; + typename Operator1::IteratorB warp_tile_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + DualMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B0_(shared_storage.operand_B0_ref(), lane_idx), + warp_tile_iterator_B1_(shared_storage.operand_B1_ref(), lane_idx) { + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/45_dual_gemm/threadblock/dual_mma_multistage.h b/examples/45_dual_gemm/threadblock/dual_mma_multistage.h new file mode 100644 index 0000000000..485922ef2e --- /dev/null +++ b/examples/45_dual_gemm/threadblock/dual_mma_multistage.h @@ -0,0 +1,775 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
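// Back-of-the-envelope sizing for DualMmaBase::SharedStorage above: A is staged once and shared by
// both GEMMs, while B0 and B1 each get their own staged buffer. The tile sizes, stage count, and
// half-precision element below are placeholders, and CUTLASS's padded smem layouts add a little on top.
#include <cstddef>
#include <cstdio>
int main() {
  int M = 128, N = 128, K = 32, stages = 3;                            // hypothetical threadblock tile and stage count
  std::size_t bytes_per_element = 2;                                   // e.g. cutlass::half_t
  std::size_t smem_A  = std::size_t(M) * K * stages * bytes_per_element;  // operand_A
  std::size_t smem_B0 = std::size_t(K) * N * stages * bytes_per_element;  // operand_B0
  std::size_t smem_B1 = std::size_t(K) * N * stages * bytes_per_element;  // operand_B1
  std::printf("approx. shared memory: %zu bytes\n", smem_A + smem_B0 + smem_B1);
  return 0;
}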
+ * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +#include "cutlass/gemm/threadblock/mma_base.h" +#include "dual_mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace gemm { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B0 operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB0_, + /// Iterates over tiles of B0 operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB0_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Iterates over tiles of B1 operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB1_, + /// Iterates over tiles of B1 operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB1_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy0_, + /// B1-specific version of the policy (concept: MmaPolicy) + typename Policy1_, + /// Number of stages, + int Stages, + /// Use zfill or predicate for out-of-bound cp.async + SharedMemoryClearOption SharedMemoryClear = SharedMemoryClearOption::kNone, + /// Used for partial specialization + typename Enable = bool> +class DualMmaMultistage : + public DualMmaBase { +public: + ///< Base class + using Base = DualMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B0 operand in global memory + using IteratorB0 = IteratorB0_; + ///< Iterates over tiles of B1 operand in global memory + using IteratorB1 = IteratorB1_; + ///< Data type of accumulator matrix + using ElementC = ElementC_; + ///< Layout of accumulator matrix + using LayoutC = LayoutC_; + ///< Policy describing tuning details + using Policy0 = Policy0_; + using Policy1 = Policy1_; + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB0 = SmemIteratorB0_; + using SmemIteratorB1 = SmemIteratorB1_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind 
const kCacheOpB = CacheOpB; + + // + // Dependent types + // + + /// Fragment of accumulator tile + using FragmentC = typename Policy0::Operator::FragmentC; + + /// Warp-level Mma + using Operator0 = typename Policy0::Operator; + using Operator1 = typename Policy1::Operator; + + /// Minimum architecture is Sm80 to support cp.async + using ArchTag = arch::Sm80; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator0::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB0 = Operator0::kTransformB; + static ComplexTransform const kTransformB1 = Operator1::kTransformB; + + /// Internal structure exposed for introspection. + struct Detail { + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB0::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand A + static int const kAccessesPerGroupA = + (AsyncCopyIterationsPerStageA + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + + private: + + using WarpLoadedFragmentA = typename Operator0::FragmentA; + using WarpLoadedFragmentB0 = typename Operator0::FragmentB; + using WarpLoadedFragmentB1 = typename Operator1::FragmentB; + using WarpTransformedFragmentA = typename Operator0::TransformedFragmentA; + using WarpTransformedFragmentB0 = typename Operator0::TransformedFragmentB; + using WarpTransformedFragmentB1 = typename Operator1::TransformedFragmentB; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB0 smem_iterator_B0_; + SmemIteratorB1 smem_iterator_B1_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + DualMmaMultistage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B0_(shared_storage.operand_B0_ref(), thread_idx), + smem_iterator_B1_(shared_storage.operand_B1_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + 
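// The warp-id decomposition just above, traced on the host with placeholder warp counts
// (the Base::WarpCount values here are assumptions, not taken from this diff):
#include <cstdio>
int main() {
  int kWarpCountM = 2, kWarpCountN = 2, kWarpCountK = 1;
  for (int warp_idx = 0; warp_idx < kWarpCountM * kWarpCountN * kWarpCountK; ++warp_idx) {
    int warp_idx_mn = warp_idx % (kWarpCountM * kWarpCountN);  // position within the MxN grid of warps
    int warp_idx_k  = warp_idx / (kWarpCountM * kWarpCountN);  // k-partition the warp belongs to
    int warp_idx_m  = warp_idx_mn % kWarpCountM;
    int warp_idx_n  = warp_idx_mn / kWarpCountM;
    std::printf("warp %d -> m=%d n=%d k=%d\n", warp_idx, warp_idx_m, warp_idx_n, warp_idx_k);
  }
  return 0;
}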
this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B0_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + this->warp_tile_iterator_B1_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA &iterator_A, IteratorB0 &iterator_B0, IteratorB1 &iterator_B1, + int group_start_A = 0, int group_start_B = 0) { + iterator_A.set_iteration_index(group_start_A * + IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupA; ++j) { + if (group_start_A + j < Detail::AsyncCopyIterationsPerStageA) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_A.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_A.valid()); + } + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + iterator_B0.set_iteration_index(group_start_B * + IteratorB0::kAccessesPerVector); + iterator_B1.set_iteration_index(group_start_B * + IteratorB1::kAccessesPerVector); + this->smem_iterator_B0_.set_iteration_index(group_start_B); + this->smem_iterator_B1_.set_iteration_index(group_start_B); + + // Async Copy for operand B0 + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B0_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB0::ThreadMap::kElementsPerAccess / + IteratorB0::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B0.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B0.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B0.valid()); + } + + ++iterator_B0; + } + ++this->smem_iterator_B0_; + } + } + // Async Copy for operand B1 + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::kAccessesPerGroupB; ++j) { + if (group_start_B + j < Detail::AsyncCopyIterationsPerStageB) { + typename IteratorB1::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + auto gmem_ptr = iterator_B1.get(); + + if (SharedMemoryClear == SharedMemoryClearOption::kZfill) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, gmem_ptr, iterator_B1.valid()); + } else { + cutlass::arch::cp_async( + dst_ptr + v, gmem_ptr, iterator_B1.valid()); + } + + ++iterator_B1; + } + ++this->smem_iterator_B1_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, 
+ ///< destination accumulator tile + FragmentC &accum0, + FragmentC &accum1, + ///< iterator over A operand in global memory + IteratorA iterator_A, + ///< iterator over B operand in global memory + IteratorB0 iterator_B0, + IteratorB1 iterator_B1, + ///< initial value of accumulator + FragmentC const &src_accum0, + FragmentC const &src_accum1 + ) { + + // + // Prologue + // + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; + ++stage, --gemm_k_iterations) { + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B0.clear_mask(gemm_k_iterations == 0); + iterator_B1.clear_mask(gemm_k_iterations == 0); + + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + int src_bytes = (iterator_A.valid() ? kSrcBytes : 0); + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + iterator_B0.set_iteration_index(0); + iterator_B1.set_iteration_index(0); + this->smem_iterator_B0_.set_iteration_index(0); + this->smem_iterator_B1_.set_iteration_index(0); + + // Async Copy for operand B0 + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB0::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B0_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB0::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB0::ThreadMap::kElementsPerAccess / + IteratorB0::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B0.get(), iterator_B0.valid()); + + ++iterator_B0; + } + + ++this->smem_iterator_B0_; + } + // Async Copy for operand B1 + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB1::AccessType *dst_ptr = + reinterpret_cast( + this->smem_iterator_B1_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB1::kAccessesPerVector; ++v) { + int const kSrcBytes = + sizeof_bits::value * + IteratorB1::ThreadMap::kElementsPerAccess / + IteratorB1::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B1.get(), iterator_B1.valid()); + + ++iterator_B1; + } + + ++this->smem_iterator_B1_; + } + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B0.add_tile_offset({1, 0}); + iterator_B1.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B0_.add_tile_offset({1, 0}); + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Defines the boundary of a stage of cp.async. + cutlass::arch::cp_async_fence(); + } + + // Perform accumulation in the 'd' output operand + accum0 = src_accum0; + accum1 = src_accum1; + + // + // Clear the remaining tiles of SMEM. This is a functional requirement for some kernels + // so that all accumulator elements outside the GEMM footprint are zero. 
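// The kSrcBytes arithmetic used by the cp.async copies above, worked for one plausible
// configuration (the element width and thread-map values are assumptions for illustration):
#include <cstdio>
int main() {
  int sizeof_bits_of_element = 16;  // e.g. cutlass::half_t
  int elements_per_access    = 8;   // hypothetical ThreadMap::kElementsPerAccess
  int accesses_per_vector    = 1;   // hypothetical kAccessesPerVector
  int src_bytes = sizeof_bits_of_element * elements_per_access / accesses_per_vector / 8;
  std::printf("each cp.async moves %d bytes\n", src_bytes);  // 16 bytes, the largest cp.async size
  return 0;
}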
+ // + + if (SharedMemoryClear == SharedMemoryClearOption::kClearLastStage) { + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA last_smem_iterator_A(this->smem_iterator_A_); + + typename IteratorA::AccessType zero_A; + zero_A.clear(); + + last_smem_iterator_A.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast( + last_smem_iterator_A.get()); + + *dst_ptr = zero_A; + + ++last_smem_iterator_A; + } + + typename IteratorB0::AccessType zero_B; + zero_B.clear(); + + /// Iterator to write threadblock-scoped tile of B0 operand to shared memory + SmemIteratorB0 last_smem_iterator_B0(this->smem_iterator_B0_); + last_smem_iterator_B0.set_iteration_index(0); + + // Async Copy for operand B0 + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB0::AccessType *dst_ptr = + reinterpret_cast( + last_smem_iterator_B0.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B0; + } + + /// Iterator to write threadblock-scoped tile of B1 operand to shared memory + SmemIteratorB1 last_smem_iterator_B1(this->smem_iterator_B1_); + last_smem_iterator_B1.set_iteration_index(0); + + // Async Copy for operand B1 + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + + typename IteratorB1::AccessType *dst_ptr = + reinterpret_cast( + last_smem_iterator_B1.get()); + + *dst_ptr = zero_B; + + ++last_smem_iterator_B1; + } + } + + // Waits until stages up to the previous (kStages-2)th stage have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA warp_loaded_frag_A[2]; + WarpLoadedFragmentB0 warp_loaded_frag_B0[2]; + WarpLoadedFragmentB1 warp_loaded_frag_B1[2]; + WarpTransformedFragmentA warp_transformed_frag_A[2]; + WarpTransformedFragmentB0 warp_transformed_frag_B0[2]; + WarpTransformedFragmentB1 warp_transformed_frag_B1[2]; + + Operator0 warp_mma0; + Operator1 warp_mma1; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B0_.set_kgroup_index(0); + this->warp_tile_iterator_B1_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[0]); + this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B0_; + ++this->warp_tile_iterator_B1_; + + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B0.clear_mask(gemm_k_iterations == 0); + iterator_B1.clear_mask(gemm_k_iterations == 0); + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma0.transform(warp_transformed_frag_A[0], warp_transformed_frag_B0[0], + warp_loaded_frag_A[0], warp_loaded_frag_B0[0]); + warp_mma1.transform(warp_transformed_frag_A[0], warp_transformed_frag_B1[0], + warp_loaded_frag_A[0], warp_loaded_frag_B1[0]); + + // tf32x3 kernels use staging accumulation. warp_mma uses a temporary + // accumulator and this temporary accumulator is added to the final + // accumulator once in every mainloop iteration. 
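// A scalar sketch of the staged accumulation described in the comment above: partial products
// land in a temporary accumulator and are folded into the final accumulator once per mainloop
// iteration, with a final drain after the loop. The iteration counts below are illustrative only.
#include <cstdio>
int main() {
  float accum = 0.0f, tmp_accum = 0.0f;
  int mainloop_iterations = 4, k_groups_per_iteration = 2;
  for (int it = 0; it < mainloop_iterations; ++it) {
    for (int k = 0; k < k_groups_per_iteration; ++k) {
      tmp_accum += 0.25f;                 // warp-level MMAs accumulate into the temporary
    }
    accum += tmp_accum;                   // fold the staged partial sum into the final accumulator
    tmp_accum = 0.0f;
  }
  std::printf("accum = %f\n", accum + tmp_accum);  // drain whatever remains, as done after the mainloop
  return 0;
}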
+ plus plus_accum; + + FragmentC tmp_accum0, tmp_accum1; + + if (platform::is_same::value + || platform::is_same::value) { + + tmp_accum0.clear(); + tmp_accum1.clear(); + } + + // + // Mainloop + // + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; + ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. + + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B0_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B1_.set_kgroup_index((warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B0_.load(warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B1_.load(warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B0_; + ++this->warp_tile_iterator_B1_; + + if (warp_mma_k > 0) { + warp_mma0.transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B0[warp_mma_k % 2]); + warp_mma1.transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B1[warp_mma_k % 2]); + } + + if (platform::is_same::value + || platform::is_same::value) { + + warp_mma0( + tmp_accum0, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + tmp_accum0 + ); + warp_mma1( + tmp_accum1, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + tmp_accum1 + ); + + if (warp_mma_k == 0) { + accum0 = plus_accum(accum0, tmp_accum0); + accum1 = plus_accum(accum1, tmp_accum1); + tmp_accum0.clear(); + tmp_accum1.clear(); + } + } else { + warp_mma0( + accum0, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B0[warp_mma_k % 2], + accum0 + ); + warp_mma1( + accum1, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B1[warp_mma_k % 2], + accum1 + ); + } + + // Issue global->shared copies for the this stage + if (warp_mma_k < Base::kWarpGemmIterations - 1) { + int group_start_iteration_A, group_start_iteration_B; + + group_start_iteration_A = warp_mma_k * Detail::kAccessesPerGroupA; + group_start_iteration_B = warp_mma_k * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B0, iterator_B1, group_start_iteration_A, + group_start_iteration_B); + } + + if (warp_mma_k + 2 == Base::kWarpGemmIterations) { + int group_start_iteration_A, group_start_iteration_B; + group_start_iteration_A = + (warp_mma_k + 1) * Detail::kAccessesPerGroupA; + group_start_iteration_B = + (warp_mma_k + 1) * Detail::kAccessesPerGroupB; + + copy_tiles_and_advance(iterator_A, iterator_B0, iterator_B1, group_start_iteration_A, + group_start_iteration_B); + + // Inserts a memory fence between stages of cp.async instructions. + cutlass::arch::cp_async_fence(); + + // Waits until stages up to the previous (kStages-2)th stage have committed. 
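// Host-side trace of the circular smem stage bookkeeping advanced below; kStages is a placeholder.
// Writes start kStages-1 ahead of reads because the prologue pre-issues that many stages.
#include <cstdio>
int main() {
  int kStages = 3;
  int smem_write_stage_idx = kStages - 1;
  int smem_read_stage_idx  = 0;
  for (int step = 0; step < 6; ++step) {
    // Advance with wraparound, mirroring the add_tile_offset(+1) / add_tile_offset(-kStages) pattern.
    smem_write_stage_idx = (smem_write_stage_idx == kStages - 1) ? 0 : smem_write_stage_idx + 1;
    smem_read_stage_idx  = (smem_read_stage_idx  == kStages - 1) ? 0 : smem_read_stage_idx  + 1;
    std::printf("step %d: write stage %d, read stage %d\n",
                step, smem_write_stage_idx, smem_read_stage_idx);
  }
  return 0;
}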
+ arch::cp_async_wait(); + __syncthreads(); + + // Move to the next stage + iterator_A.add_tile_offset({0, 1}); + iterator_B0.add_tile_offset({1, 0}); + iterator_B1.add_tile_offset({1, 0}); + + this->smem_iterator_A_.add_tile_offset({0, 1}); + this->smem_iterator_B0_.add_tile_offset({1, 0}); + this->smem_iterator_B1_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B0_.add_tile_offset({-Base::kStages, 0}); + this->smem_iterator_B1_.add_tile_offset({-Base::kStages, 0}); + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy0::kPartitionsK * + Base::kWarpGemmIterations}); + this->warp_tile_iterator_B0_.add_tile_offset( + {-Base::kStages * Policy0::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + this->warp_tile_iterator_B1_.add_tile_offset( + {-Base::kStages * Policy1::kPartitionsK * + Base::kWarpGemmIterations, + 0}); + smem_read_stage_idx = 0; + } else { + ++smem_read_stage_idx; + } + + --gemm_k_iterations; + iterator_A.clear_mask(gemm_k_iterations == 0); + iterator_B0.clear_mask(gemm_k_iterations == 0); + iterator_B1.clear_mask(gemm_k_iterations == 0); + } + + // Do any conversions feeding the first stage at the end of the loop so + // we can start right away on mma instructions + if (warp_mma_k + 1 == Base::kWarpGemmIterations) { + warp_mma0.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B0[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B0[(warp_mma_k + 1) % 2]); + warp_mma1.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B1[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B1[(warp_mma_k + 1) % 2]); + } + } + + } + + if (platform::is_same::value + || platform::is_same::value) { + accum0 = plus_accum(accum0, tmp_accum0); + accum1 = plus_accum(accum1, tmp_accum1); + } + + // commit and drain all pending and predicated cp.async pnz from the GEMM mainloop + cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/46_depthwise_simt_conv2dfprop/CMakeLists.txt b/examples/46_depthwise_simt_conv2dfprop/CMakeLists.txt new file mode 100644 index 0000000000..9a9e74c1cf --- /dev/null +++ b/examples/46_depthwise_simt_conv2dfprop/CMakeLists.txt @@ -0,0 +1,36 @@ + +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. 
Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + +cutlass_example_add_executable( + 46_depthwise_simt_conv2dfprop + depthwise_simt_conv2dfprop.cu + ) + diff --git a/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu b/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu new file mode 100644 index 0000000000..cc7d2f10f8 --- /dev/null +++ b/examples/46_depthwise_simt_conv2dfprop/depthwise_simt_conv2dfprop.cu @@ -0,0 +1,682 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +/** +This example shows how to run depthwise 2d convolution kernels using functions and data structures +provided by CUTLASS, using SIMT instructions. + +There are 3 types of implementations of depthwise 2d convolution + 1. kAnalytic + Implicit gemm 2d convolution algorithm. + 2. kOptimized + An optimized algorithm that supports arbitrary stride and dilation. + 3. kFixedStrideDilation + An optimized algorithm with fixed stride and dilation to reduce the runtime computation and enable +more optimizations. + +In general, the perf of kFixedStrideDilation would be better than kOptimized. However, if the filter +size, stride or dilation is large, it would encounter register spilling and may hurt the perf. In +that case, please use kOptimized. + +For kOptimized and kFixedStrideDilation, in order to fully utilize GPU hardware resources and achieve +better perf, splitk should be enabled when the output tensor size is large. + +This example demonstrates how to construct and run a FixedStrideDilation depthwise 2d +convolution kernel. +*/ + +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm.h" +#include "cutlass/conv/kernel/default_depthwise_fprop.h" +#include "cutlass/conv/device/implicit_gemm_convolution.h" +#include "cutlass/conv/device/direct_convolution.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/convolution.h" +#include "cutlass/util/tensor_view_io.h" + +#include "helper.h" + +// The code section below describes the datatypes for input and output tensors and the computation between +// elements +using ElementAccumulator = cutlass::half_t; // Data type of accumulator +using ElementComputeEpilogue = cutlass::half_t; // Data type of epilogue computation (alpha, beta) +using ElementInputA = cutlass::half_t; // Data type of elements in input tensor +using ElementInputB = cutlass::half_t; // Data type of elements in input tensor +using ElementOutput = cutlass::half_t; // Data type of elements in output tensor + +using LayoutInputA = cutlass::layout::TensorNHWC; +using LayoutInputB = cutlass::layout::TensorNHWC; +using LayoutOutput = cutlass::layout::TensorNHWC; + +// This code section describes whether you want to use tensor cores or regular SIMT cores on GPU SM +using MMAOp = cutlass::arch::OpClassSimt; + +// This code section describes CUDA SM architecture number +using SmArch = cutlass::arch::Sm60; + +// This code section describes the groups a thread block will compute +constexpr int groups_per_cta = 64; + +// This code section describes the output tile a thread block will compute +using ThreadBlockOutputShape = cutlass::conv::TensorNHWCShape<1, 8, 8, groups_per_cta>; + +// This code section describes the filter shape +using FilterShape = cutlass::MatrixShape<3, 3>; + +// Threadblock tile shape +using ThreadblockShape = + cutlass::gemm::GemmShape; + +// This code section describes the tile size a warp will compute +// WarpShape::kM = P * Q the warps would process +// WarpShape::kN = groups_per_cta that the warps would process +// WarpShape::kK = filter_size that the warps would process +using WarpShape =
cutlass::gemm::GemmShape<16, groups_per_cta, FilterShape::kCount>; + +// This code section describes the size of MMA op +using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + +// This code section describes how threadblocks are scheduled on GPU +using SwizzleThreadBlock = + cutlass::conv::threadblock::DepthwiseDirect2dConvIdentityThreadblockSwizzle< + 1, + ThreadBlockOutputShape::kN, + ThreadBlockOutputShape::kH, + ThreadBlockOutputShape::kW>; + +// Number of pipelines you want to use +constexpr int NumStages = 4; + +// This code section describe iterator algorithm selected is kFixedStrideDilation +static cutlass::conv::IteratorAlgorithm const IteratorAlgorithm = + cutlass::conv::IteratorAlgorithm::kFixedStrideDilation; +using StrideShape = cutlass::MatrixShape<1, 1>; +using DilationShape = cutlass::MatrixShape<1, 1>; + +constexpr int kEpilogueElementsPerAccess = 128 / cutlass::sizeof_bits::value; + +// This code section describes the epilogue part of the kernel, we use default value +using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementOutput, // Data type of output matrix. + kEpilogueElementsPerAccess, // The number of elements per vectorized. + // memory access. This becomes the vector width of + // math instructions in the epilogue too. + ElementAccumulator, // Data type of accumulator + ElementComputeEpilogue, // Data type for alpha/beta in linear combination + cutlass::epilogue::thread::ScaleType::OnlyAlphaScaling>; // Epilogue scaling operation. + +using DepthwiseDirect2dConv = typename cutlass::conv::kernel::DefaultDepthwiseDirect2dConvFprop< + ElementInputA, + LayoutInputA, + ElementInputB, + LayoutInputB, + ElementOutput, + LayoutOutput, + ElementAccumulator, + MMAOp, + SmArch, + ThreadblockShape, + ThreadBlockOutputShape, + FilterShape, + WarpShape, + InstructionShape, + EpilogueOp, + SwizzleThreadBlock, + NumStages, + cutlass::arch::OpMultiplyAdd, + IteratorAlgorithm, + cutlass::conv::StrideSupport::kFixed, + StrideShape, + DilationShape>::Kernel; + +using Direct2dConv = cutlass::conv::device::DirectConvolution; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + bool help; + cutlass::Tensor4DCoord input_size; + cutlass::Tensor4DCoord filter_size; + cutlass::Tensor4DCoord padding; + cutlass::MatrixCoord conv_stride; + cutlass::MatrixCoord dilation; + int groups; + int splitk; + bool reference_check; + bool measure_performance; + int iterations; + bool save_workspace; + ElementComputeEpilogue alpha; + ElementComputeEpilogue beta; + std::string tag; + + Options() + : help(false), + input_size(1, 128, 128, 32), + filter_size(32, 3, 3, 1), + groups(32), + padding(1, 1, 1, 1), + conv_stride(1, 1), + dilation(1, 1), + reference_check(false), + measure_performance(true), + iterations(20), + save_workspace(false), + alpha(1), + beta(0), + splitk(1) {} + + // Verify the problem size is compatible with the CUTLASS Convolution implementation. + bool valid() { + // + // CUTLASS attempts to load 128b vectors of cutlass::half_t (F16) elements. Consequently, + // all pointers, strides, and tensor extents must be divisible by 8 elements. 
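// The alignment requirement above, made explicit: the kernel issues 128-bit vector accesses of
// cutlass::half_t, so extents must be multiples of 128 / 16 = 8 elements (the kAlignment checked below).
#include <cstdio>
int main() {
  int access_bits  = 128;                           // vector width of the global-memory accesses
  int element_bits = 16;                            // cutlass::half_t
  int kAlignment   = access_bits / element_bits;    // = 8 elements
  std::printf("C and K must be multiples of %d\n", kAlignment);
  return 0;
}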
+ // + int const kAlignment = 8; + + if ((input_size.c() % kAlignment) || (filter_size.n() % kAlignment)) { + // misaligned tensors + return false; + } + + // depthwise conv + if (groups != input_size.c()) { + return false; + } + + if (filter_size.n() != groups) { + return false; + } + + // Invalid padding + if ((padding.h() != filter_size.h() / 2) || (padding.w() != filter_size.w() / 2)) { + return false; + } + + // Filter size passed through command line does not match filter size template parameter + if (filter_size.h() != FilterShape::kRow || filter_size.w() != FilterShape::kColumn) { + std::cerr << "Filter size passed in (" << filter_size.h() << "x" << filter_size.w() << ") " + << "must match the FilterShape template parameter of the convolution " + << "(" << FilterShape::kRow << "x" << FilterShape::kColumn << "). " + << "To use the filter shape passed in, change the FilterShape template " + << "parameter and recompile this example." + << std::endl; + return false; + } + + return true; + } + + /// Updates input and filter sizes + void update(cutlass::Tensor4DCoord input_size, cutlass::Tensor4DCoord filter_size) { + this->input_size = input_size; + this->filter_size = filter_size; + + padding.n() = filter_size.h() / 2; + padding.h() = filter_size.h() / 2; + padding.w() = filter_size.w() / 2; + padding.c() = filter_size.w() / 2; + } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + if (cmd.check_cmd_line_flag("ref-check")) { + reference_check = true; + } + + if (cmd.check_cmd_line_flag("perf-check")) { + measure_performance = true; + } + + if (cmd.check_cmd_line_flag("save-workspace")) { + save_workspace = true; + } + + cmd.get_cmd_line_argument("n", input_size.n()); + cmd.get_cmd_line_argument("h", input_size.h()); + cmd.get_cmd_line_argument("w", input_size.w()); + cmd.get_cmd_line_argument("c", input_size.c()); + + cmd.get_cmd_line_argument("k", filter_size.n()); + cmd.get_cmd_line_argument("r", filter_size.h()); + cmd.get_cmd_line_argument("s", filter_size.w()); + + cmd.get_cmd_line_argument("g", groups); + + filter_size.c() = 1; + filter_size.n() = input_size.c(); + + cmd.get_cmd_line_argument("alpha", alpha); + cmd.get_cmd_line_argument("beta", beta); + cmd.get_cmd_line_argument("splitk", splitk); + + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("tag", tag); + + int32_t padding_h = filter_size.h() / 2; + int32_t padding_w = filter_size.w() / 2; + padding = {padding_h, padding_h, padding_w, padding_w}; + } + + /// Prints the usage statement. 
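// The padding rule in parse() and update() above, evaluated for the default 3x3 filter: half the
// filter extent on each side keeps H and W unchanged at unit stride ("same" padding).
#include <cstdio>
int main() {
  int R = 3, S = 3;
  int padding_h = R / 2;   // = 1
  int padding_w = S / 2;   // = 1
  std::printf("padding = {%d, %d, %d, %d}\n", padding_h, padding_h, padding_w, padding_w);
  return 0;
}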
+ std::ostream &print_usage(std::ostream &out) const { + out << "46_depthwise_gemm_fprop example\n\n" + << " This example uses Ampere's Tensor Core operators on F16 data types to compute\n" + << " forward convolution on tensors of layout NHWC.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --n= Input tensor extent N\n" + << " --h= Input tensor extent H\n" + << " --w= Input tensor extent W\n" + << " --c= Input tensor extent C\n" + << " --k= Filter extent K\n" + << " --r= Filter extent R\n" + << " --s= Filter extent S\n\n" + << " --g= Groups\n\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --splitk= Enable splitK\n\n" + << " --ref-check If set (true), reference check on the host is computed\n" + << " --perf-check If set (true), performance is measured.\n" + << " --iterations= Number of profiling iterations to perform.\n" + << " --save-workspace If set, workspace is written to a text file.\n" + << " --tag= String to replicate across the first column in the results " + "table\n"; + + out << "\n\nExamples:\n\n" + << "$ ./examples/46_depthwise_simt_conv2dfprop/46_depthwise_simt_conv2dfprop --n=32 " + "--h=224 --w=224 --c=128 --k=128 --g=128 --r=3 --s=3\n\n" + << "$ ./examples/46_depthwise_simt_conv2dfprop/46_depthwise_simt_conv2dfprop --n=1 " + "--h=224 --w=224 --c=32 --k=32 --g=32 --r=3 --s=3 --splitk=10 --ref-check\n\n"; + + return out; + } + + /// Computes the output tensor size (NPQK) + cutlass::Tensor4DCoord output_size() const { + return cutlass::Tensor4DCoord( + input_size.n(), + (input_size.h() + padding.n() + padding.h() - filter_size.h()) / conv_stride.row() + 1, + (input_size.w() + padding.w() + padding.c() - filter_size.w()) / conv_stride.column() + 1, + filter_size.n()); + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + // Number of multiply-adds = NPQK * CRS + int64_t fmas = + output_size().product() * int64_t(filter_size.h() * filter_size.w() * filter_size.c()); + + // Two flops per multiply-add + return 2.0 * double(fmas) / double(1.0e9) / runtime_s; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct Result { + double runtime_ms; + double gflops; + cutlass::Status status; + cutlass::Status reference_check; + cudaError_t error; + + Result() + : runtime_ms(0), + gflops(0), + status(cutlass::Status::kSuccess), + reference_check(cutlass::Status::kInvalid), + error(cudaSuccess) {} + + static std::ostream &print_header(std::ostream &out, Options const &options) { + if (!options.tag.empty()) { + out << "Name,"; + } + + out << "Layer,N,H,W,C,K,R,S,G,stride_h,stride_w,dilation_h,dilation_w,splitK,Runtime,GFLOPs"; + + return out; + } + + std::ostream &print(std::ostream &out, int idx, Options const &options) { + if (!options.tag.empty()) { + out << options.tag << ","; + } + + cutlass::Tensor4DCoord output_size = options.output_size(); + out << "conv_" << idx << "," << options.input_size.n() << "," << options.input_size.h() << "," + << options.input_size.w() << "," << options.input_size.c() << "," + + << options.filter_size.n() << "," << options.filter_size.h() << "," + << options.filter_size.w() << "," + + << options.groups << "," << options.conv_stride.row() << "," << options.conv_stride.column() + << "," + + << options.dilation.row() << "," << options.dilation.column() << "," + + << options.splitk << "," + + << runtime_ms << "," << gflops; + + return out; + } +}; + 
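// Options::output_size() and Options::gflops() above, evaluated on this example's default problem
// (1x128x128x32 input, 32 filters of 3x3x1, stride 1, padding 1):
#include <cstdio>
int main() {
  long N = 1, H = 128, W = 128, K = 32, R = 3, S = 3, C_per_group = 1;
  long pad_h = 1, pad_w = 1, stride_h = 1, stride_w = 1;
  long P = (H + 2 * pad_h - R) / stride_h + 1;               // output height = 128
  long Q = (W + 2 * pad_w - S) / stride_w + 1;               // output width  = 128
  long long fmas = 1LL * N * P * Q * K * (R * S * C_per_group);
  double gflop = 2.0 * double(fmas) / 1.0e9;                 // two flops per multiply-add
  std::printf("output = %ldx%ldx%ldx%ld, %.3f GFLOP per launch\n", N, P, Q, K, gflop);
  return 0;
}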
+///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Runs one testcase +Result profile_convolution(Options const &options) { + Result result; + + // + // Allocate host-device tensors using the CUTLASS Utilities. + // + + cutlass::HostTensor tensor_a(options.input_size); + cutlass::HostTensor tensor_b(options.filter_size); + cutlass::HostTensor tensor_b_transpose(options.filter_size); + cutlass::HostTensor tensor_c(options.output_size()); + cutlass::HostTensor tensor_d(options.output_size()); + cutlass::HostTensor tensor_ref_d(options.output_size()); + + // + // Initialize tensors + // + + // Fill tensor A on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_a.host_view(), 1, ElementInputA(5), ElementInputA(-6), 0); + + // Fill tensor B on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_b.host_view(), 1, ElementInputB(3), ElementInputB(-6), 0); + + // Fill tensor C on host with uniform-distribution random data + cutlass::reference::host::TensorFillRandomUniform( + tensor_c.host_view(), 1, ElementOutput(5), ElementOutput(-6), 0); + + // Fill tensor D on host with zeros + cutlass::reference::host::TensorFill(tensor_d.host_view()); + + // Fill tensor D for reference on host with zeros + cutlass::reference::host::TensorFill(tensor_ref_d.host_view()); + + // Copy data from host to GPU + tensor_a.sync_device(); + tensor_b.sync_device(); + tensor_b_transpose.sync_device(); + tensor_c.sync_device(); + tensor_d.sync_device(); + tensor_ref_d.sync_device(); + + // + // Define arguments for CUTLASS Convolution + // + + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation; + + // Split P*Q into multiple CTA + int split_k_slices = options.splitk; + + // Construct Conv2dProblemSize with user defined output size + cutlass::conv::Conv2dProblemSize problem_size(options.input_size, + options.filter_size, + options.padding, + options.conv_stride, + options.dilation, + options.output_size(), + mode, + split_k_slices, + options.groups); + + // Construct Direc2dConv::Argument structure with conv2d + // problem size, data pointers, and epilogue values + typename Direct2dConv::Arguments arguments{problem_size, + tensor_a.device_ref(), + tensor_b.device_ref(), + tensor_c.device_ref(), + tensor_d.device_ref(), + {options.alpha, options.beta}, + tensor_b_transpose.device_ref()}; + + // + // Initialize CUTLASS Convolution + // + + Direct2dConv implicit_gemm_op; + + size_t workspace_size = implicit_gemm_op.get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + result.status = implicit_gemm_op.can_implement(arguments); + CUTLASS_CHECK(result.status); + + result.status = implicit_gemm_op.initialize(arguments, workspace.get()); + CUTLASS_CHECK(result.status); + + // + // Launch initialized CUTLASS kernel + // + result.status = implicit_gemm_op(); + + CUTLASS_CHECK(result.status); + + // + // Optional reference check + // + + if (options.reference_check) { + std::cout << "Verification on host...\n"; + + // Compute with reference implementation + cutlass::reference::host::Conv2dFprop< + ElementInputA, + LayoutInputA, + ElementInputB, + LayoutInputB, + ElementOutput, + LayoutOutput, + ElementComputeEpilogue, + ElementAccumulator >(problem_size, + tensor_a.host_ref(), + tensor_b.host_ref(), + tensor_c.host_ref(), + tensor_ref_d.host_ref(), + options.alpha, + options.beta); + + 
// Check if output from CUTLASS kernel and reference kernel are equal or not + tensor_d.sync_host(); + + bool passed = + cutlass::reference::host::TensorEquals(tensor_d.host_view(), tensor_ref_d.host_view()); + + if (!passed) { + result.reference_check = cutlass::Status::kErrorInternal; + std::cout << "ERROR - results miscompared.\n"; + } else { + result.reference_check = cutlass::Status::kSuccess; + std::cout << "Passed.\n"; + } + } else { + result.reference_check = cutlass::Status::kInvalid; + } + + if (options.save_workspace) { + std::stringstream ss; + + ss << "46_depthwise_simt_conv2dfprop" << options.input_size.n() << "x" << options.input_size.h() + << "x" << options.input_size.w() << "x" << options.input_size.c() << "_" + << options.filter_size.n() << "x" << options.filter_size.h() << "x" + << options.filter_size.w() << "x" << options.filter_size.c() << ".dat"; + + std::ofstream output_workspace(ss.str()); + + output_workspace << "Input = \n" + << tensor_a.host_view() << "\n\n" + << "Filters = \n" + << tensor_b.host_view() << "\n\n"; + + if (options.reference_check) { + output_workspace << "Reference = \n" << tensor_ref_d.host_view() << "\n\n"; + } + + output_workspace << "Computed = \n" << tensor_d.host_view() << std::endl; + + std::cout << "Results written to '" << ss.str() << "'." << std::endl; + } + + // + // Performance measurement + // + + if (options.measure_performance) { + cudaEvent_t events[2]; + + for (auto &event : events) { + result.error = cudaEventCreate(&event); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + } + + // Record an event at the start of a series of convolution operations. + result.error = cudaEventRecord(events[0]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Launch a sequence of implicit GEMM operations on the device + for (int iteration = 0; iteration < options.iterations; ++iteration) { + result.status = implicit_gemm_op(); + CUTLASS_CHECK(result.status); + } + + // Record an event when the convolutions have been launched. + result.error = cudaEventRecord(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Wait for work on the device to complete. + result.error = cudaEventSynchronize(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) + << std::endl; + return result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Print average runtime and GFLOPs. 
+ result.runtime_ms = double(runtime_ms) / double(options.iterations); + result.gflops = options.gflops(result.runtime_ms / 1000.0); + + // Cleanup + for (auto event : events) { + (void)cudaEventDestroy(event); + } + } + + return result; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + bool notSupported = false; + + cudaDeviceProp props; + CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); + + if (!(props.major >= 6)) { + std::cerr << "Run on a machine with compute capability at least 60." << std::endl; + notSupported = true; + } + + if (notSupported) { + return 0; + } + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // Execute one problem size + if (!options.valid()) { + std::cerr << "Invalid problem." << std::endl; + return -1; + } + + Result result = profile_convolution(options); + + Result::print_header(std::cout, options) << std::endl; + result.print(std::cout, 1, options) << std::endl; + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/47_ampere_gemm_universal_streamk/CMakeLists.txt b/examples/47_ampere_gemm_universal_streamk/CMakeLists.txt new file mode 100644 index 0000000000..00be87ede6 --- /dev/null +++ b/examples/47_ampere_gemm_universal_streamk/CMakeLists.txt @@ -0,0 +1,45 @@ + +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ + +cutlass_example_add_executable( + 47_ampere_gemm_universal_streamk + ampere_gemm_universal_streamk.cu + ) + +# Deliberately test non-square sizes to ensure that internal transpose is +# not triggered when using SM80 EVT +set(TEST_COMMAND_00 --m=512 --n=768 --k=1152) + +cutlass_example_add_executable( + 47_ampere_gemm_universal_streamk_broadcast + ampere_gemm_universal_streamk_broadcast.cu + TEST_COMMAND_OPTIONS + TEST_COMMAND_00 + ) diff --git a/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk.cu b/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk.cu new file mode 100644 index 0000000000..76bd0979de --- /dev/null +++ b/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk.cu @@ -0,0 +1,592 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*************************************************************************************************** + Example contrasting the Stream-K parallel decomposition for GEMM threadblocks versus the + "classic data-parallel" and "Split-K" decompositions. + + For more details regarding the Stream-K method, see "Stream-K: Work-centric Parallel Decomposition + for Dense Matrix-Matrix Multiplication on the GPU" (https://arxiv.org/abs/2301.03598) + + Requires NVIDIA Ampere or newer device (SM80+). + + - To lock persistence mode, power (400W), clocks (1005MHz) for evaluation (assumes device 0 and A100) + + cutlass$ sudo nvidia-smi -pm 1 -i 0 + + cutlass$ sudo nvidia-smi -i 0 -pl 400 + + cutlass$ sudo nvidia-smi -i 0 -lgc 1005 + + - Build and run: + + cutlass$ mkdir build + + cutlass$ cd build + + cutlass/build$ cmake .. 
-DCUTLASS_NVCC_ARCHS=80 + + cutlass/build$ make 47_ampere_gemm_universal_streamk + + cutlass/build$ ./examples/47_ampere_gemm_universal_streamk/47_ampere_gemm_universal_streamk + + 10000 timing iterations of 2048 x 2048 x 2048 matrix-matrix multiply + + Basic data-parallel GEMM + Disposition: Passed + Avg runtime: 0.112633 ms + GFLOPs: 152530 + + StreamK GEMM with default load-balancing + Disposition: Passed + Avg runtime: 0.0941929 ms + GFLOPs: 182390 + Speedup vs Basic-DP: 1.196 + + StreamK emulating basic data-parallel GEMM + Disposition: Passed + Avg runtime: 0.113119 ms + GFLOPs: 151875 + Speedup vs Basic-DP: 0.996 + + Basic split-K GEMM with tile-splitting factor 2 + Disposition: Passed + Avg runtime: 0.104772 ms + GFLOPs: 163973 + + StreamK emulating Split-K GEMM with tile-splitting factor 2 + Disposition: Passed + Avg runtime: 0.105379 ms + GFLOPs: 163029 + Speedup vs Basic-SplitK: 0.994 + + **************************************************************************************************/ + +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "helper.h" + + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations (cutlass_tensorop_h16816gemm_128x128_32x4_nn_align8) +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = cutlass::half_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C/D matrices in units of elements (up to 16 bytes) + +// Multiply-accumulate blocking/pipelining details +using ElementAccumulator = cutlass::half_t; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm80; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock-level tile size (concept: GemmShape) +using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp-level tile size (concept: GemmShape) +using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // Instruction-level tile size (concept: GemmShape) +constexpr int NumStages = 4; // Number of global->shared pipeline stages used in the GEMM mainloop + +// 
Epilogue output operator +using EpilogueOp = cutlass::epilogue::thread::LinearCombination< + ElementC, // Element type for C and D matrix operands + AlignmentC, // Memory access granularity of C and D matrix in units of elements + ElementAccumulator, // Element type from internal accumaccumulation + ElementAccumulator>; // Data type used to compute linear combination + +// Reference device GEMM implementation type +using DeviceGemmReference = cutlass::reference::device::Gemm< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + ElementAccumulator>; + +// Classic data-parallel device GEMM implementation type +using DeviceGemmBasic = cutlass::gemm::device::GemmUniversal< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + NumStages, + AlignmentA, + AlignmentB>; + +// StreamK device GEMM implementation type +using DeviceGemmStreamK = cutlass::gemm::device::GemmUniversal< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOp, + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, // <-- Only difference + NumStages, + AlignmentA, + AlignmentB>; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(true) + {} + +}; + + +/// Command line options parsing +struct Options +{ + std::string command_name; + bool help; + cutlass::gemm::GemmCoord problem_size; + float alpha; + float beta; + int split_k_factor; + int avail_sms; + bool reference_check; + int iterations; + + cutlass::HostTensor tensor_a; + cutlass::HostTensor tensor_b; + cutlass::HostTensor tensor_c; + cutlass::HostTensor tensor_d; + cutlass::HostTensor tensor_ref_d; + + Options(std::string command_name) : + command_name(command_name), + help(false), + problem_size({2048, 2048, 2048}), + alpha(1.0f), + beta(0.0f), + split_k_factor(1), + avail_sms(-1), // Number of device SMs to use is unlimited + reference_check(true), + iterations(10000) + {} + + bool valid() const + { + return true; + } + + void parse(int argc, char const **args) + { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + cmd.get_cmd_line_argument("m", problem_size.m()); + cmd.get_cmd_line_argument("n", problem_size.n()); + cmd.get_cmd_line_argument("k", problem_size.k()); + cmd.get_cmd_line_argument("alpha", alpha); + cmd.get_cmd_line_argument("beta", beta); + cmd.get_cmd_line_argument("split", split_k_factor); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const + { + out + << "Performs a GEMM computation.\n" + << "\n" + << "Options:\n" + << "\n" + << " --help If specified, displays this usage statement.\n\n" + << " --m= GEMM M dimension\n" + << " --n= GEMM N dimension\n" + << " --k= GEMM K dimension\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --split= Split-K factor to emulate\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << command_name << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + return 2.0 * double(problem_size.product()) / double(1.0e9) / runtime_s; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Populates a DeviceGemmBasic::Arguments structure from the given commandline options +typename DeviceGemmBasic::Arguments args_from_options( + const DeviceGemmBasic &device_gemm, + const Options &options, + cutlass::HostTensor &tensor_a, + cutlass::HostTensor &tensor_b, + cutlass::HostTensor &tensor_c, + cutlass::HostTensor &tensor_d) +{ + return typename DeviceGemmBasic::Arguments( + cutlass::gemm::GemmUniversalMode::kGemm, // universal mode + options.problem_size, // problem_size + options.split_k_factor, // batch count / splitk slices + { // epilogue parameters + ElementAccumulator(options.alpha), + ElementAccumulator(options.beta) + }, + tensor_a.device_data(), // ptr_A + tensor_b.device_data(), // ptr_B + tensor_c.device_data(), // ptr_C + tensor_d.device_data(), // ptr_D + options.problem_size.mk().product(), // batch_stride_A + options.problem_size.nk().product(), // batch_stride_B + options.problem_size.mn().product(), // batch_stride_C + options.problem_size.mn().product(), // batch_stride_D + tensor_a.layout().stride(0), // stride_a + tensor_b.layout().stride(0), // stride_b + tensor_c.layout().stride(0), // stride_c + tensor_d.layout().stride(0)); // stride_d +} + +/// Populates a DeviceGemmStreamK::Arguments structure from the given commandline options +typename DeviceGemmStreamK::Arguments args_from_options( + const DeviceGemmStreamK &device_gemm, + const Options &options, + cutlass::HostTensor &tensor_a, + cutlass::HostTensor &tensor_b, + cutlass::HostTensor &tensor_c, + cutlass::HostTensor &tensor_d) +{ + return typename DeviceGemmStreamK::Arguments( + cutlass::gemm::GemmUniversalMode::kGemm, // universal mode + options.problem_size, // problem_size + options.split_k_factor, // batch count / splitk slices + { // epilogue parameters + ElementAccumulator(options.alpha), + ElementAccumulator(options.beta) + }, + tensor_a.device_data(), // ptr_A + tensor_b.device_data(), // ptr_B + tensor_c.device_data(), // ptr_C + tensor_d.device_data(), // ptr_D + options.problem_size.mk().product(), // batch_stride_A + options.problem_size.nk().product(), // batch_stride_B + options.problem_size.mn().product(), // batch_stride_C + options.problem_size.mn().product(), // batch_stride_D + tensor_a.layout().stride(0), // stride_a + tensor_b.layout().stride(0), // stride_b + tensor_c.layout().stride(0), // stride_c + tensor_d.layout().stride(0), // stride_d + options.avail_sms); // avail_sms +} + + +/// Execute a given example GEMM computation 
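The gflops() helper above counts two flops per multiply-accumulate over the full M x N x K volume. Plugging in the default 2048 x 2048 x 2048 problem together with the 0.112633 ms runtime quoted in the sample output near the top of this file reproduces the reported figure; a minimal check of that arithmetic:

// Sanity check of the GFLOP/s arithmetic used by Options::gflops(),
// using the runtime quoted in the file header's sample output.
#include <cstdio>

int main() {
  double m = 2048, n = 2048, k = 2048;
  double flop = 2.0 * m * n * k;              // two flops per multiply-add
  double runtime_s = 0.112633e-3;             // 0.112633 ms from the sample output
  std::printf("GFLOP/s: %.0f\n", flop / 1.0e9 / runtime_s);   // prints ~152530
  return 0;
}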
+template +Result run(std::string description, Options &options) +{ + // Display test description + std::cout << std::endl << description << std::endl; + + // Zero-initialize test output matrix D + cutlass::reference::host::TensorFill(options.tensor_d.host_view()); + options.tensor_d.sync_device(); + + // Instantiate CUTLASS kernel depending on templates + DeviceGemmT device_gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of DeviceGemmT + auto arguments = args_from_options(device_gemm, options, options.tensor_a, options.tensor_b, options.tensor_c, options.tensor_d); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = DeviceGemmT::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + CUTLASS_CHECK(device_gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(device_gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(device_gemm()); + + // Copy output data from CUTLASS and reference kernel to host for comparison + options.tensor_d.sync_host(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = cutlass::reference::host::TensorEquals( + options.tensor_d.host_view(), + options.tensor_ref_d.host_view()); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(device_gemm()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPs: " << result.gflops << std::endl; + } + + if (!result.passed) { + exit(-1); + } + + return result; +} + + +/// Program entrypoint +int main(int argc, const char **argv) +{ + // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. + if (!(__CUDACC_VER_MAJOR__ >= 11)) { + std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; + + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + // Current device must must have compute capability at least 80 + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + if (!((props.major * 10 + props.minor) >= 80)) + { + std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." + << std::endl; + + // Returning zero so this test passes on older Toolkits. Its actions are no-op. 
+ return 0; + } + + // Parse commandline options + Options options("ampere_streamk_gemm"); + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + std::cout << + options.iterations << " timing iterations of " << + options.problem_size.m() << " x " << + options.problem_size.n() << " x " << + options.problem_size.k() << " matrix-matrix multiply" << std::endl; + + if (!options.valid()) { + std::cerr << "Invalid problem." << std::endl; + return -1; + } + + + // + // Initialize GEMM datasets + // + + // Initialize tensors using CUTLASS helper functions + options.tensor_a.resize(options.problem_size.mk()); // <- Create matrix A with dimensions M x K + options.tensor_b.resize(options.problem_size.kn()); // <- Create matrix B with dimensions K x N + options.tensor_c.resize(options.problem_size.mn()); // <- Create matrix C with dimensions M x N + options.tensor_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from CUTLASS kernel + options.tensor_ref_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from reference kernel + + // Fill matrix A on host with uniform-random data [-2, 2] + cutlass::reference::host::TensorFillRandomUniform( + options.tensor_a.host_view(), + 1, + ElementA(2), + ElementA(-2), + 0); + + // Fill matrix B on host with uniform-random data [-2, 2] + cutlass::reference::host::TensorFillRandomUniform( + options.tensor_b.host_view(), + 1, + ElementB(2), + ElementB(-2), + 0); + + // Fill matrix C on host with uniform-random data [-2, 2] + cutlass::reference::host::TensorFillRandomUniform( + options.tensor_c.host_view(), + 1, + ElementC(2), + ElementC(-2), + 0); + + + // + // Compute reference output + // + + // Copy data from host to GPU + options.tensor_a.sync_device(); + options.tensor_b.sync_device(); + options.tensor_c.sync_device(); + + // Zero-initialize reference output matrix D + cutlass::reference::host::TensorFill(options.tensor_ref_d.host_view()); + options.tensor_ref_d.sync_device(); + + // Create instantiation for device reference gemm kernel + DeviceGemmReference gemm_reference; + + // Launch device reference gemm kernel + gemm_reference( + options.problem_size, + ElementAccumulator(options.alpha), + options.tensor_a.device_ref(), + options.tensor_b.device_ref(), + ElementAccumulator(options.beta), + options.tensor_c.device_ref(), + options.tensor_ref_d.device_ref()); + + // Wait for kernels to finish + CUDA_CHECK(cudaDeviceSynchronize()); + + // Copy output data from reference kernel to host for comparison + options.tensor_ref_d.sync_host(); + + + // + // Evaluate CUTLASS kernels + // + + // Test default operation + if (options.split_k_factor == 1) + { + // Compare basic data-parallel version versus StreamK version using default load-balancing heuristics + Result basic_dp = run("Basic data-parallel GEMM", options); + Result streamk_default = run("StreamK GEMM with default load-balancing", options); + + printf(" Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_default.avg_runtime_ms)); + + // Show that StreamK can emulate basic data-parallel GEMM when we set the number of SMs to load-balance across = 1 + options.avail_sms = 1; // Set loadbalancing width to 1 SM (no load balancing) + Result streamk_dp = run("StreamK emulating basic data-parallel GEMM", options); + options.avail_sms = -1; // Reset loadbalancing width to unspecified SMs (i.e., the number of device SMs) + + printf(" Speedup 
vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_dp.avg_runtime_ms)); + + options.split_k_factor++; // Increment splitting factor for next evaluation + + } + + // Show that StreamK can emulate "Split-K" with a tile-splitting factor + Result basic_splitk = run( + std::string("Basic split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor), + options); + + Result streamk_splitk = run( + std::string("StreamK emulating Split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor), + options); + + printf(" Speedup vs Basic-SplitK: %.3f\n", (basic_splitk.avg_runtime_ms / streamk_splitk.avg_runtime_ms)); + + return 0; +} diff --git a/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu b/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu new file mode 100644 index 0000000000..ed65e58c89 --- /dev/null +++ b/examples/47_ampere_gemm_universal_streamk/ampere_gemm_universal_streamk_broadcast.cu @@ -0,0 +1,738 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*************************************************************************************************** + Example contrasting the Stream-K parallel decomposition for GEMM threadblocks versus the + "classic data-parallel" and "Split-K" decompositions + residual add. + + For more details regarding the Stream-K method, see "Stream-K: Work-centric Parallel Decomposition + for Dense Matrix-Matrix Multiplication on the GPU" (https://arxiv.org/abs/2301.03598) + + Requires NVIDIA Ampere or newer device (SM80+). 
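The description above contrasts the Stream-K decomposition with the classic data-parallel and Split-K schedules. As a conceptual aid only (this is not CUTLASS's internal scheduler code), Split-K with a tile-splitting factor divides the K extent into that many contiguous slices whose partial products are reduced afterwards:

// Conceptual sketch of Split-K partitioning: the K extent is split into
// `split` contiguous ranges, one per slice; the partial products over each
// range are reduced into the final output.
#include <cstdio>

int main() {
  int K = 2048;
  int split = 2;                              // --split factor from the command line
  for (int s = 0; s < split; ++s) {
    int k_begin = (K * s) / split;
    int k_end   = (K * (s + 1)) / split;
    std::printf("slice %d covers k in [%d, %d)\n", s, k_begin, k_end);
  }
  return 0;
}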
+ + - To lock persistence mode, power (400W), clocks (1005MHz) for evaluation (assumes device 0 and A100) + + cutlass$ sudo nvidia-smi -pm 1 -i 0 + + cutlass$ sudo nvidia-smi -i 0 -pl 400 + + cutlass$ sudo nvidia-smi -i 0 -lgc 1005 + + - Build and run: + + cutlass$ mkdir build + + cutlass$ cd build + + cutlass/build$ cmake .. -DCUTLASS_NVCC_ARCHS=80 + + cutlass/build$ make 47_ampere_gemm_universal_streamk_broadcast + + cutlass/build$ ./examples/47_ampere_gemm_universal_streamk/47_ampere_gemm_universal_streamk_broadcast + + - Reset clocks when done: + + cutlass$ sudo nvidia-smi -rgc + + **************************************************************************************************/ + +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_with_broadcast.h" +#include "cutlass/gemm/device/gemm_universal_streamk_with_broadcast.h" +#include "cutlass/epilogue/thread/linear_combination_residual_block.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/host/error_metrics.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_foreach.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/tensor_view_io.h" + +#include "cutlass/epilogue/threadblock/fusion/visitors.hpp" +#include "cutlass/gemm/kernel/default_gemm_universal_with_visitor.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + +#include "helper.h" + + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations (cutlass_tensorop_h16816gemm_128x128_32x4_nn_align8) +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::RowMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C1/C2/D matrix configuration +using ElementC = cutlass::half_t; // Element type for C matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrices in units of elements (up to 16 bytes) + +// Output matrix configuration +using ElementOutput = cutlass::half_t; // Element type for output matrix operands +using LayoutOutput = cutlass::layout::RowMajor; // Layout type for output matrix operands +// constexpr int AlignmentOutput = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of output matrices in units of elements (up to 16 bytes) + +// Multiply-accumulate blocking/pipelining details +using ElementAccumulator = cutlass::half_t; // Element type for internal accumulation +using ElementCompute = cutlass::half_t; // Element type for compute 
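The Alignment constants in the operand configurations above express the widest vectorized access in units of elements: 128 bits per access divided by the element width, which for the half-precision operands used here is 8 elements, i.e. the 16-byte accesses mentioned in the comments. A minimal sketch of that arithmetic, assuming 16-bit elements:

// Alignment in elements = access width in bits / element width in bits.
#include <cstdio>

int main() {
  constexpr int kAccessBits  = 128;           // widest global-memory access
  constexpr int kElementBits = 16;            // width of cutlass::half_t
  constexpr int kAlignment   = kAccessBits / kElementBits;
  static_assert(kAlignment == 8, "half_t operands use 8-element (16-byte) accesses");
  std::printf("alignment in elements: %d\n", kAlignment);
  return 0;
}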
+using ArchTag = cutlass::arch::Sm80; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using ThreadblockShape = cutlass::gemm::GemmShape<128, 128, 32>; // Threadblock-level tile size (concept: GemmShape) +using WarpShape = cutlass::gemm::GemmShape<64, 64, 32>; // Warp-level tile size (concept: GemmShape) +using InstructionShape = cutlass::gemm::GemmShape<16, 8, 16>; // Instruction-level tile size (concept: GemmShape) +constexpr int NumStages = 4; // Number of global->shared pipeline stages used in the GEMM mainloop +constexpr int EVTEpilogueStages = 1; // Number of epilogue stages in EVT + +// Residual block configuration + +// Epilogue output operator +/// Using LinearCombinationResidualBlock +/// Models a residual block of the form: UnaryOp(BinaryOp(BinaryOp(ActivationOp(TensorOp(X) + bias), residual1), residual2)) +using EpilogueOp = cutlass::epilogue::thread::LinearCombinationResidualBlock< + ElementOutput, // Element type for output matrix + ElementAccumulator, // Element type from internal accumulation + ElementCompute, // Element type from internal accumulation + ElementC, // Element type for C1/C2/D matrix operands + AlignmentC, // Memory access granularity of C and D matrix in units of elements + cutlass::epilogue::thread::Identity, // Activation + cutlass::plus, // Binary operation 1 + cutlass::epilogue::thread::Identity, // Unary operation + cutlass::plus // Binary operation 2 + >; + +// Reference device GEMM implementation type +using DeviceGemmReference = cutlass::reference::device::Gemm< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + ElementAccumulator>; + +// Classic data-parallel device GEMM implementation type +using DeviceGemmBasic = cutlass::gemm::device::GemmUniversalWithBroadcast< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOp, + cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, + NumStages, + AlignmentA, + AlignmentB>; + +// StreamK device GEMM implementation type with EVT +using namespace cute; + +using OutputTileThreadMap = cutlass::epilogue::threadblock::OutputTileThreadLayout< + ThreadblockShape, + WarpShape, + ElementC, + AlignmentC, + EVTEpilogueStages +>; + +using Accum = cutlass::epilogue::threadblock::VisitorAccFetch; + +using Bias = cutlass::epilogue::threadblock::VisitorRowBroadcast< + OutputTileThreadMap, ElementC, + cute::Stride<_0, _1, int32_t> // StrideMNL +>; + +using C1 = cutlass::epilogue::threadblock::VisitorAuxLoad< + OutputTileThreadMap, ElementC, + cute::Stride // StrideMNL +>; + +using C2 = cutlass::epilogue::threadblock::VisitorAuxLoad< + OutputTileThreadMap, ElementC, + cute::Stride // StrideMNL +>; + +using Compute0 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::plus, ElementCompute, ElementCompute, + cutlass::FloatRoundStyle::round_to_nearest +>; + +using EVTCompute0 = cutlass::epilogue::threadblock::Sm80EVT< + Compute0, + Accum, + Bias>; + +using Compute1 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::plus, ElementCompute, ElementCompute, + cutlass::FloatRoundStyle::round_to_nearest +>; + +using EVTCompute1 = cutlass::epilogue::threadblock::Sm80EVT< + Compute1, + EVTCompute0, + C1>; + +using Compute2 = cutlass::epilogue::threadblock::VisitorCompute< + cutlass::plus, ElementOutput, ElementCompute, + 
cutlass::FloatRoundStyle::round_to_nearest +>; + +using EVTCompute2 = cutlass::epilogue::threadblock::Sm80EVT< + Compute2, + EVTCompute1, + C2>; + +using D = cutlass::epilogue::threadblock::VisitorAuxStore< + OutputTileThreadMap, ElementOutput, cutlass::FloatRoundStyle::round_to_nearest, + cute::Stride // StrideMNL +>; + +using EVTD = cutlass::epilogue::threadblock::Sm80EVT< + D, + EVTCompute2>; + +using EVTKernelStreamK = + typename cutlass::gemm::kernel::DefaultGemmWithVisitor< + ElementA, LayoutA, cutlass::ComplexTransform::kNone, AlignmentA, + ElementB, LayoutB, cutlass::ComplexTransform::kNone, AlignmentB, + ElementC, LayoutC, AlignmentC, + ElementAccumulator, + ElementCompute, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm80, + ThreadblockShape, + WarpShape, + InstructionShape, + EVTD, + cutlass::gemm::threadblock::ThreadblockSwizzleStreamK, + NumStages, + cutlass::arch::OpMultiplyAdd, + EVTEpilogueStages +>::GemmKernel; + +using DeviceGemmStreamK = cutlass::gemm::device::GemmUniversalAdapter; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(true) + {} + +}; + + +/// Command line options parsing +struct Options +{ + std::string command_name; + bool help; + cutlass::gemm::GemmCoord problem_size; + float alpha; + float beta; + int split_k_factor; + int avail_sms; + int iterations; + bool real; + + cutlass::HostTensor tensor_a; + cutlass::HostTensor tensor_b; + cutlass::HostTensor tensor_c1; + cutlass::HostTensor tensor_c2; + cutlass::HostTensor tensor_d; + cutlass::HostTensor tensor_ref_d; + cutlass::HostTensor tensor_Vector; + // cutlass::HostTensor tensor_Tensor; + + Options(std::string command_name) : + command_name(command_name), + help(false), + problem_size({2048, 2048, 2048}), + alpha(1.0f), + beta(1.0f), + split_k_factor(1), + avail_sms(-1), // Number of device SMs to use is unlimited + real(false), + iterations(10000) + {} + + bool valid() const + { + return true; + } + + void parse(int argc, char const **args) + { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + cmd.get_cmd_line_argument("m", problem_size.m()); + cmd.get_cmd_line_argument("n", problem_size.n()); + cmd.get_cmd_line_argument("k", problem_size.k()); + cmd.get_cmd_line_argument("alpha", alpha); + cmd.get_cmd_line_argument("beta", beta); + cmd.get_cmd_line_argument("split", split_k_factor); + cmd.get_cmd_line_argument("iterations", iterations); + real = cmd.check_cmd_line_flag("real"); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const + { + out + << "Performs a GEMM computation.\n" + << "\n" + << "Options:\n" + << "\n" + << " --help If specified, displays this usage statement.\n\n" + << " --m= GEMM M dimension\n" + << " --n= GEMM N dimension\n" + << " --k= GEMM K dimension\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --split= Split-K factor to emulate\n\n" + << " --real If specified, initializes with real values instead of whole numbers. Errors are to be expected.\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << command_name << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + return 2.0 * double(problem_size.product()) / double(1.0e9) / runtime_s; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Populates a DeviceGemmBasic::Arguments structure from the given commandline options +typename DeviceGemmBasic::Arguments args_from_options( + const DeviceGemmBasic &device_gemm, + const Options &options, + cutlass::HostTensor &tensor_a, + cutlass::HostTensor &tensor_b, + cutlass::HostTensor &tensor_c1, + cutlass::HostTensor &tensor_c2, + cutlass::HostTensor &tensor_d, + cutlass::HostTensor &tensor_Vector /*, + cutlass::HostTensor &tensor_Tensor */ + ) +{ + return typename DeviceGemmBasic::Arguments( + cutlass::gemm::GemmUniversalMode::kGemm, // universal mode + options.problem_size, // problem_size + options.split_k_factor, // batch count / splitk slices + { // epilogue parameters + ElementAccumulator(options.alpha), + ElementAccumulator(options.beta) + }, + tensor_a.device_data(), // ptr_A + tensor_b.device_data(), // ptr_B + tensor_c1.device_data(), // ptr_C1 + tensor_c2.device_data(), // ptr_C2 + tensor_d.device_data(), // ptr_D + tensor_Vector.device_data(), // ptr_Vector + /* tensor_Tensor.device_data(), */nullptr,// ptr_Tensor + options.problem_size.mk().product(), // batch_stride_A + options.problem_size.nk().product(), // batch_stride_B + options.problem_size.mn().product(), // batch_stride_C1 + options.problem_size.mn().product(), // batch_stride_C2 + options.problem_size.mn().product(), // batch_stride_D + options.problem_size.mn().product(), // batch_stride_Vector + options.problem_size.mn().product(), // batch_stride_Tensor + tensor_a.layout().stride(0), // stride_a + tensor_b.layout().stride(0), // stride_b + tensor_c1.layout().stride(0), // stride_c1 + tensor_c2.layout().stride(0), // stride_c2 + tensor_d.layout().stride(0), // stride_d + /*tensor_Vector.layout().stride(0)*/0, // stride_Vector + /*tensor_Tensor.layout().stride(0)*/0); // stride_Tensor +} + +/// Populates a DeviceGemmStreamK::Arguments structure from the given commandline options +typename DeviceGemmStreamK::Arguments args_from_options( + const DeviceGemmStreamK &device_gemm, + const Options &options, + cutlass::HostTensor &tensor_a, + cutlass::HostTensor &tensor_b, + cutlass::HostTensor &tensor_c1, + cutlass::HostTensor &tensor_c2, + cutlass::HostTensor &tensor_d, + cutlass::HostTensor &tensor_Vector/*, + cutlass::HostTensor &tensor_Tensor*/ + ) +{ + typename EVTD::Arguments callback_args{ + { + { + { + {}, // Accum + {tensor_Vector.device_data(), 
ElementC(0), {_0{}, _1{}, int32_t(options.problem_size.n())}}, // Bias + {} // Compute0 + }, // EVTCompute0 + {tensor_c1.device_data(), ElementC(0), {options.problem_size.n(), _1{}, options.problem_size.mn().product()}}, // C1 + {} // Compute1 + }, // EVTCompute1 + {tensor_c2.device_data(), ElementC(0), {options.problem_size.n(), _1{}, options.problem_size.mn().product()}}, // C2 + {} // Compute2 + }, // EVTCompute2 + {tensor_d.device_data(), {options.problem_size.n(), _1{}, options.problem_size.mn().product()}}, // D + }; // EVTD + + return typename DeviceGemmStreamK::Arguments( + cutlass::gemm::GemmUniversalMode::kGemm, // universal mode + options.problem_size, // problem_size + options.split_k_factor, // batch count / splitk slices + callback_args, // argument of EVT callbacks + tensor_a.device_data(), // ptr_A + tensor_b.device_data(), // ptr_B + nullptr, // ptr_C (unused) + nullptr, // ptr_D (unused) + options.problem_size.mk().product(), // batch_stride_A + options.problem_size.nk().product(), // batch_stride_B + 0, // batch_stride_C (unused) + 0, // batch_stride_D (unused) + tensor_a.layout().stride(0), // stride_a + tensor_b.layout().stride(0), // stride_b + 0, // stride_c (unused) + 0, // stride_d (unused) + options.avail_sms); // avail_sms +} + +/// Execute a given example GEMM computation +template +Result run(std::string description, Options &options) +{ + // Display test description + std::cout << std::endl << description << std::endl; + + // Zero-initialize test output matrix D + cutlass::reference::host::TensorFill(options.tensor_d.host_view()); + options.tensor_d.sync_device(); + + // Instantiate CUTLASS kernel depending on templates + DeviceGemmT device_gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of DeviceGemmT + auto arguments = args_from_options(device_gemm, options, + options.tensor_a, options.tensor_b, options.tensor_c1, options.tensor_c2, options.tensor_d, + options.tensor_Vector/*, options.tensor_Tensor*/); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = DeviceGemmT::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check the problem size is supported or not + CUTLASS_CHECK(device_gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(device_gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(device_gemm()); + + // Copy output data from CUTLASS and reference kernel to host for comparison + options.tensor_d.sync_host(); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = cutlass::reference::host::TensorEquals( + options.tensor_d.host_view(), + options.tensor_ref_d.host_view()); + + double err = cutlass::reference::host::TensorRelativeErrorMetric( + options.tensor_d.host_view(), + options.tensor_ref_d.host_view()); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << " \t Relative error: " << err << std::endl; + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(device_gemm()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. 
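The visitor tree assembled in callback_args above composes three plus nodes: Compute0 adds the row-broadcast bias to the accumulator, Compute1 adds the C1 operand, Compute2 adds C2, and the D node stores the result. A host-side sketch of that per-element composition, using small hypothetical sizes and values purely for illustration:

// Per-element view of the EVT tree: D = ((Accum + Bias) + C1) + C2.
#include <cstdio>
#include <vector>

int main() {
  int M = 2, N = 3;
  std::vector<float> acc(M * N, 1.0f);        // stands in for the GEMM accumulators
  std::vector<float> bias(N, 0.5f);           // row-broadcast vector (one value per column)
  std::vector<float> c1(M * N, 2.0f);         // first residual operand
  std::vector<float> c2(M * N, 3.0f);         // second residual operand
  std::vector<float> d(M * N);

  for (int i = 0; i < M; ++i) {
    for (int j = 0; j < N; ++j) {
      float compute0 = acc[i * N + j] + bias[j];   // EVTCompute0: Accum + Bias
      float compute1 = compute0 + c1[i * N + j];   // EVTCompute1: ... + C1
      float compute2 = compute1 + c2[i * N + j];   // EVTCompute2: ... + C2
      d[i * N + j] = compute2;                     // D: aux store
    }
  }
  std::printf("d[0] = %f\n", d[0]);                // 1 + 0.5 + 2 + 3 = 6.5
  return 0;
}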
+ float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPs: " << result.gflops << std::endl; + } + + // TODO: uncomment when results match + //if (!result.passed) { + // exit(-1); + //} + + return result; +} + + +/// Program entrypoint +int main(int argc, const char **argv) +{ + // CUTLASS must be compiled with CUDA 11.0 Toolkit to run these examples. + if (!(__CUDACC_VER_MAJOR__ >= 11)) { + std::cerr << "Ampere Tensor Core operations must be compiled with CUDA 11.0 Toolkit or later." << std::endl; + + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + // Current device must must have compute capability at least 80 + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + if (!((props.major * 10 + props.minor) >= 80)) + { + std::cerr << "Ampere Tensor Core operations must be run on a machine with compute capability at least 80." + << std::endl; + + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + // Parse commandline options + Options options("ampere_streamk_broadcast_gemm"); + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + std::cout << + options.iterations << " timing iterations of " << + options.problem_size.m() << " x " << + options.problem_size.n() << " x " << + options.problem_size.k() << " matrix-matrix multiply" << std::endl; + + if (!options.valid()) { + std::cerr << "Invalid problem." << std::endl; + return -1; + } + + + // + // Initialize GEMM datasets + // + + // Initialize tensors using CUTLASS helper functions + options.tensor_a.resize(options.problem_size.mk()); // <- Create matrix A with dimensions M x K + options.tensor_b.resize(options.problem_size.kn()); // <- Create matrix B with dimensions K x N + options.tensor_c1.resize(options.problem_size.mn()); // <- Create matrix C1 with dimensions M x N + options.tensor_c2.resize(options.problem_size.mn()); // <- Create matrix C2 with dimensions M x N + options.tensor_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from CUTLASS kernel + options.tensor_ref_d.resize(options.problem_size.mn()); // <- Create matrix D with dimensions M x N used to store output from reference kernel + options.tensor_Vector.resize({1, options.problem_size.n()}); // <- Create broadcast vector with dimensions N x 1 + // options.tensor_Tensor.resize(options.problem_size.mn()); // <- Create T matrix with dimensions M x N + + int _init_bits = options.real ? 
-1 : 0; + + // Fill matrix A on host with uniform-random data [-2, 2] + cutlass::reference::host::TensorFillRandomUniform( + options.tensor_a.host_view(), + 1, + ElementA(2), + ElementA(-2), _init_bits); + + // Fill matrix B on host with uniform-random data [-2, 2] + cutlass::reference::host::TensorFillRandomUniform( + options.tensor_b.host_view(), + 1, + ElementB(2), + ElementB(-2), _init_bits); + + // Fill matrix C1 on host with uniform-random data [-2, 2] + cutlass::reference::host::TensorFillRandomUniform( + options.tensor_c1.host_view(), + 1, + ElementC(2), + ElementC(-2), _init_bits); + + // Fill matrix C2 on host with uniform-random data [-2, 2] + cutlass::reference::host::TensorFillRandomUniform( + options.tensor_c2.host_view(), + 1, + ElementC(2), + ElementC(-2), _init_bits); + + cutlass::reference::host::TensorFillRandomUniform( + options.tensor_Vector.host_view(), + 1, + ElementC(2), + ElementC(-2), _init_bits); + + // + // Compute reference output + // + + // Copy data from host to GPU + options.tensor_a.sync_device(); + options.tensor_b.sync_device(); + options.tensor_c1.sync_device(); + options.tensor_c2.sync_device(); + options.tensor_Vector.sync_device(); + // options.tensor_Tensor.sync_device(); + + // Zero-initialize reference output matrix D + cutlass::reference::host::TensorFill(options.tensor_ref_d.host_view()); + options.tensor_ref_d.sync_device(); + + // Create instantiation for device reference gemm kernel + DeviceGemmReference gemm_reference; + + // Launch device reference gemm kernel + gemm_reference( + options.problem_size, + ElementAccumulator(options.alpha), + options.tensor_a.device_ref(), + options.tensor_b.device_ref(), + ElementAccumulator(options.beta), + options.tensor_c1.device_ref(), + options.tensor_ref_d.device_ref()); + + // Wait for kernels to finish + CUDA_CHECK(cudaDeviceSynchronize()); + + // Copy output data from reference kernel to host for comparison + options.tensor_ref_d.sync_host(); + + // Add broadcast vector (without multiplier) + // This is only possible because BinaryOp is addition, and UnaryOps are identity. + // This makes the addition of broadcast vector commutable. 
+ /// identity(plus(identity(alpha * (a * b) + v), beta * c)) == + /// alpha * a * b + v + beta * c == + /// (alpha * a * b + beta * c) + v == + /// GEMM(a, b, c) + v + // Vector broadcast on host + for (int i=0; i < options.problem_size.m(); ++i) { + for (int j=0; j < options.problem_size.n(); ++j) { + options.tensor_ref_d.host_view().ref().at({i, j}) += options.tensor_Vector.host_view().ref().at({0, j}); + options.tensor_ref_d.host_view().ref().at({i, j}) += options.tensor_c2.host_view().ref().at({i, j}); + } + } + + // Sync back with device just in case + options.tensor_ref_d.sync_device(); + + // + // Evaluate CUTLASS kernels + // + + // Test default operation + if (options.split_k_factor == 1) + { + // Compare basic data-parallel version versus StreamK version using default load-balancing heuristics + Result basic_dp = run("Basic data-parallel GEMM", options); + Result streamk_default = run("StreamK GEMM with default load-balancing", options); + + printf(" Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_default.avg_runtime_ms)); + + // Show that StreamK can emulate basic data-parallel GEMM when we set the number of SMs to load-balance across = 1 + options.avail_sms = 1; // Set loadbalancing width to 1 SM (no load balancing) + Result streamk_dp = run("StreamK emulating basic data-parallel GEMM", options); + options.avail_sms = -1; // Reset loadbalancing width to unspecified SMs (i.e., the number of device SMs) + + printf(" Speedup vs Basic-DP: %.3f\n", (basic_dp.avg_runtime_ms / streamk_dp.avg_runtime_ms)); + + options.split_k_factor++; // Increment splitting factor for next evaluation + + } + + // Show that StreamK can emulate "Split-K" with a tile-splitting factor + Result basic_splitk = run( + std::string("Basic split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor), + options); + + Result streamk_splitk = run( + std::string("StreamK emulating Split-K GEMM with tile-splitting factor ") + std::to_string(options.split_k_factor), + options); + + printf(" Speedup vs Basic-SplitK: %.3f\n", (basic_splitk.avg_runtime_ms / streamk_splitk.avg_runtime_ms)); + + return 0; +} diff --git a/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu b/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu new file mode 100644 index 0000000000..164c785e01 --- /dev/null +++ b/examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm.cu @@ -0,0 +1,510 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Simple Hopper GEMM example using CUTLASS 3.0 APIs for NVIDIA Hopper architecture + + This example demonstrate a simple way to instantiate and run a TF32 GEMM using the new CUTLASS 3.0 + APIs on NVIDIA Hopper architecture. New features that will be showcased in this example are as follows: + + 1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA) + which are more efficient than the Ampere tensor core instructions. + + 2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large + blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous + copies between thread blocks in a cluster. Another advantage is that TMA can load in FP32 data and + convert them implicitly to TF32. + + 3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details). + + 4. A simple way to tune the CTA rasterization direction and swizzle pattern of Hopper kernels. Both the + CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning we can + improve performance. 
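Point 4 above refers to the CTA rasterization direction and swizzle that the example exposes later through the scheduler arguments (arguments.scheduler.raster_order and arguments.scheduler.max_swizzle_size). As a conceptual illustration only, the two directions differ in which tile index moves fastest when the output tiles are visited; the sketch below prints both walk orders for a tiny hypothetical tile grid and does not reflect the actual persistent-scheduler implementation:

// Conceptual illustration of the two tile-visitation orders selected by --raster.
#include <cstdio>

int main() {
  int tiles_m = 2, tiles_n = 3;               // hypothetical output tile grid

  std::printf("N-fastest order:");
  for (int m = 0; m < tiles_m; ++m) {
    for (int n = 0; n < tiles_n; ++n) {
      std::printf(" (%d,%d)", m, n);
    }
  }

  std::printf("\nM-fastest order:");
  for (int n = 0; n < tiles_n; ++n) {
    for (int m = 0; m < tiles_m; ++m) {
      std::printf(" (%d,%d)", m, n);
    }
  }
  std::printf("\n");
  return 0;
}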
+ + Examples: + + $ ./examples/48_hopper_warp_specialized_gemm/48_hopper_warp_specialized_gemm --m=2048 --n=2048 --k=2048 --rasterization=N --swizzle=2 +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = float; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = float; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = float; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_128,_128,_32>; // Threadblock-level tile size +using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster +using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size +using KernelSchedule = cutlass::gemm::collective::KernelScheduleAuto; // Kernel to launch based on the default setting in the Collective Builder + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementC, LayoutC, AlignmentC, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + +using CollectiveMainloop = typename 
cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Reference device GEMM implementation type +using DeviceGemmReference = cutlass::reference::device::Gemm< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + ElementAccumulator>; + +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +// +// Data members +// + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; +uint64_t seed; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions; + +// Command line options parsing +struct Options { + + bool help; + + float alpha, beta; + int iterations; + int m, n, k; + RasterOrderOptions raster; + int swizzle; + + Options(): + help(false), + m(5120), n(4096), k(4096), + alpha(1.f), beta(0.f), + iterations(1000), + raster(RasterOrderOptions::Heuristic), + swizzle(1) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + + char raster_char; + cmd.get_cmd_line_argument("raster", raster_char); + + if (raster_char == 'N' || raster_char == 'n') { + raster = RasterOrderOptions::AlongN; + } + else if (raster_char == 'M' || raster_char == 'm') { + raster = RasterOrderOptions::AlongM; + } + else if (raster_char == 'H' || raster_char == 'h') { + raster = RasterOrderOptions::Heuristic; + } + + cmd.get_cmd_line_argument("swizzle", swizzle, 1); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "48_hopper_warp_specialized_gemm\n\n" + << " Hopper FP32 GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --raster= CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n" + << " --swizzle= CTA Rasterization swizzle\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "48_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = Element(2); + scope_min = Element(0); + } else if (bits_input <= 8) { + scope_max = Element(2); + scope_min = Element(-2); + } else { + scope_max = Element(8); + scope_min = Element(-8); + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, {options.m, options.k, 1}); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, {options.n, options.k, 1}); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, {options.m, options.n, 1}); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, {options.m, options.n, 1}); + + block_A.reset(options.m * options.k); + block_B.reset(options.k * options.n); + block_C.reset(options.m * options.n); + block_D.reset(options.m * options.n); + block_ref_D.reset(options.m * options.n); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k}, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, 
block_D.get(), stride_D} + }; + + arguments.scheduler.raster_order = options.raster; + // The tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8) + arguments.scheduler.max_swizzle_size = options.swizzle; + + return arguments; +} + +bool verify(const Options &options) { + cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({options.m, options.k})); + cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({options.k, options.n})); + cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({options.m, options.n})); + cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({options.m, options.n})); + + // + // Compute reference output + // + + // Create instantiation for device reference gemm kernel + DeviceGemmReference gemm_reference; + + // Launch device reference gemm kernel + gemm_reference( + {options.m, options.n, options.k}, + ElementAccumulator(options.alpha), + ref_A, + ref_B, + ElementAccumulator(options.beta), + ref_C, + ref_D); + + // Wait for kernel to finish + CUDA_CHECK(cudaDeviceSynchronize()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. 
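+    // Note: the timed region above re-runs gemm.initialize() on every iteration, so the reported
+    // time includes argument/workspace setup as well as the kernel launch itself.
+    // gflops() assumes 2*M*N*K floating-point operations and expects a runtime in seconds,
+    // hence the division by 1000 below.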
+ float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::string raster = "Heuristic"; + + if (options.raster == RasterOrderOptions::AlongN) { + raster = "Along N"; + } + else if (options.raster == RasterOrderOptions::AlongM) { + raster = "Along M"; + } + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl; + std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // and must have compute capability at least 90. + if (__CUDACC_VER_MAJOR__ < 12) { + std::cerr << "This example requires CUDA 12 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major < 9) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater).\n"; + return 0; + } + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + run(options); +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/48_hopper_warp_specialized_gemm/CMakeLists.txt b/examples/48_hopper_warp_specialized_gemm/CMakeLists.txt new file mode 100644 index 0000000000..903da1ea6f --- /dev/null +++ b/examples/48_hopper_warp_specialized_gemm/CMakeLists.txt @@ -0,0 +1,35 @@ + +# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + +cutlass_example_add_executable( + 48_hopper_warp_specialized_gemm + 48_hopper_warp_specialized_gemm.cu + ) diff --git a/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu b/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu new file mode 100644 index 0000000000..1e820ddb47 --- /dev/null +++ b/examples/49_hopper_gemm_with_collective_builder/49_collective_builder.cu @@ -0,0 +1,652 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Hopper GEMM example leveraging collective operation builders. + + This example showcases the use of CUTLASS's CollectiveBuilder to easily construct performant kernels + targeting the NVIDIA Hopper architecture. + + Background and motivation + ------------------------- + CUTLASS kernels are highly parameterizable via template parameters. To ease the selection of template + parameters, CUTLASS 2 leveraged DefaultGemmConfigurations. Given a small set of parameters, such as + the data types of operands and the compute capability of the GPU, DefaultGemmConfigurations defined sensible + defaults for the many other parameters to the kernel (e.g., warp shape, stage count). 
+ + However, DefaultGemmConfigurations leave multiple opportunities for improvement, which are addressed + in CUTLASS 3: + (1) DefaultGemmConfigurations do not allow one to use a more-performant set of parameters without + specifying every parameter. For example, the DefaultGemmConfigurations for GEMMs targeting + Ampere specify that three pipeline stages should be used regardless of the sizes of operands. + If one wished to increase this value, one would also need to specify all other template parameters. + This leaves a gap between a high-level ease-of-use interface and a lower-level detailed interface. + (2) A new DefaultGemmConfiguration was required for each combination of operand types, GPU architecture, + and operation type (e.g., Tensor Core or SIMT). This led to increased code size to cover each unique + configuration and a lack of extensibility from one DefaultGemmConfiguration to another. + + Alongside these opportunities for improvement, the Hopper architecture offers new features that increase + the number of valid configurations of a kernel. In addition to the many template parameters already available + in CUTLASS 2 kernels, CUTLASS 3 kernels targeting Hopper also have various scheduling modes to select from that control: + (1) how data is to be loaded (e.g., using the Hopper TMA feature or Ampere cp.async) + (2) how work is to be divided among warps in a thread block (e.g., whether to use "warp specialization") + (3) whether persistent thread blocks should be used + This increased configuration space further motivates rethinking DefaultGemmConfigurations. + + Introduction to the CollectiveBuilder + ------------------------------------- + CUTLASS 3 introduces the CollectiveBuilder to further ease the process of selecting template parameters + for kernels targeting Hopper. Similar to the DefaultGemmConfigurations used in CUTLASS 2, the CollectiveBuilder + takes in a small set of template parameters (e.g., the data types of operands A and B). It then automatically + determines the data loading strategy to use depending on whether the Hopper TMA feature can be used with the provided + parameters. If one does not indicate a particular scheduling policy or stage count to use (by using `Auto` template + parameters), the CollectiveBuilder will also automatically select these. + + Unlike DefaultGemmConfigurations a partial specialization of the CollectiveBuilder is not needed for many + configurations of operand types. Instead the CollectiveBuilder "builds" a configuration based on generic + properties of the specified operands, layouts, and other parameters. For example, when the stage count + is set to `Auto`, the CollectiveBuilder may automatically calculate the maximum number of stages that + will fit in shared memory given the types of operands and the thread block shape, rather than simply using + a single default value. + + CUTLASS 3.x provides builders for both collective mainloops and epilogues. The particular implementation of + the collective is specified via the schedule tags that corresond to the underlying collective's + dispatch policy. `gemm::collective::KernelScheduleAuto` and `epilogue::collective::EpilogueScheduleAuto` + are special cases of these schedules that allow the builder to also decide the dispatch policy for you, + therefore letting the builder pick the collective specialization. 
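+
+    As a sketch of what this looks like in practice (the concrete builder invocations used by this
+    example appear inside the ExampleRunner further below), a mainloop collective with fully
+    automatic policy selection can be declared as:
+
+      using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+          cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+          cutlass::half_t, cutlass::layout::RowMajor,    8,    // A: element type, layout, alignment
+          cutlass::half_t, cutlass::layout::ColumnMajor, 8,    // B: element type, layout, alignment
+          float,                                               // accumulator type
+          cute::Shape<cute::_128,cute::_128,cute::_64>,        // tile shape
+          cute::Shape<cute::_2,cute::_1,cute::_1>,             // cluster shape
+          cutlass::gemm::collective::StageCountAuto,           // let the builder pick the stage count
+          cutlass::gemm::collective::KernelScheduleAuto        // let the builder pick the schedule
+        >::CollectiveOp;
+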
+ + CUTLASS builders make an attempt to pick the best schedule when `Auto` is provided such that the + assembled collectives have the best performance, but this is not a guarantee. A user relying on `Auto` + may get a free performance upgrade with newer CUTLASS releases in case we can provide more optimized + implementations that the builder can transparently assemble for `Auto`. But a user should not rely on + `Auto` if they require a specific scheduling policy and/or stage count to be used. + + If a user decides to let the builders pick the collective specialization via `Auto` schedules, + they must be used for both mainloop and epilogue alike to ensure compatibility between the + chosen collectives. Additionally, if a user chooses to opt in to a specific schedule, non-`Auto` + schedules must be used for both mainloop and epilogue builder schedules, and these schedules + must be compatible. + + One does not need to use the CollectiveBuilder to declare CUTLASS 3 kernels; one can still provide + every template parameter to the `gemm::collective::CollectiveMma`. Specifying every template parameter + in this manner remains the primary API for using CUTLASS 3 kernels. `CollectiveBuilder`s are + simply meant to be a convenience interface. + + Details of this example + ----------------------- + This example walks through the use of the CollectiveBuilder with various schedules and stage counts specified. + This example also illustrates how CUTLASS 3 GEMMs targeting Hopper automatically support batched GEMMs by simply + extending the problem size with an additional tensor rank. + + CUTLASS 3.2 provides initial support for epilogue visitor trees (EVT) for the TMA warp-specialized collective. + EVTs allow users to define their own customized epilogue fusion patterns without having to write a new + collective epilogue. This is done by representing the fusion as a compute graph, where each node is one of a + fundamental set of load, store, or compute operations. These operations are either elementwise for tensor + inputs/outputs, broadcasts for vector/scalar inputs, or reductions for vector/scalar outputs. + This example shows how users can define their own custom EVT and use it with the CollectiveBuilder. 
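+
+    As a concrete illustration, the familiar linear combination D = alpha * acc + beta * C can be
+    expressed as a small tree of these fusion nodes. The sketch below mirrors the node structure of
+    the CustomEVT type defined in the ExampleRunner further down in this file; refer to that
+    definition for the exact element types and rounding style used:
+
+      using CustomEVT =  // D = beta * C + (alpha * acc)
+        cutlass::epilogue::fusion::Sm90EVT<
+          cutlass::epilogue::fusion::Sm90Compute<
+            cutlass::homogeneous_multiply_add, ElementD, ElementCompute, RoundStyle>,   // multiply-add
+          cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementScalar>,                // beta
+          cutlass::epilogue::fusion::Sm90SrcFetch<ElementC>,                            // C
+          cutlass::epilogue::fusion::Sm90EVT<
+            cutlass::epilogue::fusion::Sm90Compute<
+              cutlass::multiplies, ElementCompute, ElementCompute, RoundStyle>,         // multiply
+            cutlass::epilogue::fusion::Sm90ScalarBroadcast<ElementScalar>,              // alpha
+            cutlass::epilogue::fusion::Sm90AccFetch                                     // acc
+          >
+        >;
+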
+ + Example usage: + $ ./examples/49_hopper_with_collective_builder/49_collective_builder \ + --m=2048 --n=2048 --k=2048 --l=2 +*/ + +#include + +#include "cute/tensor.hpp" + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Command line options parsing +struct Options { + + bool help; + bool error; + + int m, n, k, l; + float alpha, beta; + + Options(): + help(false), + error(false), + m(2048), n(2048), k(2048), l(1), + alpha(1.f), beta(0.f) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 2048); + cmd.get_cmd_line_argument("n", n, 2048); + cmd.get_cmd_line_argument("k", k, 2048); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "49_hopper_with_collective_builder\n\n" + << " This example showcases the use of CUTLASS's collective operation builders to easily construct\n" + << " performant kernels targeting NVIDIA's Hopper architecture.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + + return true; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +// Wrapper to construct, run, and verify a GEMM. 
This example showcases CUTLASS's collective +// operation builders by specializing the GEMM only on the kernel schedule it will use and the +// number of pipeline stages. +// +// One can use a special `Auto` type that tells the CollectiveBuilder +// to select an appropriate value on its own. The CollectiveBuilder will attempt to select +// configurations that will result in the most-performant kernel, but this is not a guarantee. +// +// If relying on 'Auto' schedules, all builders must use the 'Auto' schedule to ensure compatiblity. +// For example, if `KernelScheduleAuto` is used for the mainloop builder, `EpilogueScheduleAuto` must +// be used for the epilogue builder. +// +// Furthermore, if an override schedule is selected, both epilogue and mainloop schedules must +// be specifically opt into a compatible selection. +// +// Behavior of the CollectiveBuilder with `Auto` types is subject to change in future releases +// -- do not rely on `Auto` if you require a specific scheduling policy. +template < + // Type of kernel schedule to generate + class MainloopScheduleType = cutlass::gemm::collective::KernelScheduleAuto, + // Type of epilogue schedule to generate + class EpilogueScheduleType = cutlass::epilogue::collective::EpilogueScheduleAuto, + // Number of pipeline stages to use + class StageCountType = cutlass::gemm::collective::StageCountAuto, + // Type of tile scheduler to use + class TileSchedulerType = cutlass::gemm::PersistentScheduler, + // Do we use custom epilogue visitor tree (EVT) fusion + bool UseCustomEVT = false +> +struct ExampleRunner { + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using LayoutD = cutlass::layout::ColumnMajor; + + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + using ElementAccumulator = float; + using ElementCompute = float; + using ElementScalar = float; + + // 16B alignment lets us use TMA + static constexpr int AlignmentA = 16 / sizeof(ElementA); + static constexpr int AlignmentB = 16 / sizeof(ElementB); + static constexpr int AlignmentC = 16 / sizeof(ElementC); + static constexpr int AlignmentD = 16 / sizeof(ElementD); + + static_assert(not UseCustomEVT || + (cute::is_same_v || + cute::is_same_v), + "Epilogue visitor trees are currently only supported by the TMA warp-specialized epilogue"); + static constexpr auto RoundStyle = cutlass::FloatRoundStyle::round_to_nearest; + + // EVTs can be constructed by composing the fundamental load/store/compute visitor operations defined in include/cutlass/epilogue/fusion + // For more complex examples of EVT construction please refer to include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp + using CustomEVT = // alpha * acc + beta * C + cutlass::epilogue::fusion::Sm90EVT, // beta * C + (alpha * acc) + cutlass::epilogue::fusion::Sm90ScalarBroadcast, // beta + cutlass::epilogue::fusion::Sm90SrcFetch, // C + cutlass::epilogue::fusion::Sm90EVT, // alpha * acc + cutlass::epilogue::fusion::Sm90ScalarBroadcast, // alpha + cutlass::epilogue::fusion::Sm90AccFetch // acc + > + >; + + // A predefined set of fusion operations (implemented with EVT) are supported by the TMA warp-specialized epilogue. + // Users can select one of these operations by passing one of the tags defined in include/cutlass/epilogue/fusion/operations.hpp + // to the CollectiveBuilder. 
This frees the user from having to compute additional parameters such as stage counts and copy atoms/layouts. + // These tags also provide additional metadata that can be queried at compile time. + using DefaultOperation = cutlass::epilogue::fusion::LinearCombination; + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_64>, Shape<_1,_1,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueScheduleType, + cute::conditional_t + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + Shape<_128,_128,_64>, Shape<_2,_1,_1>, + cute::conditional_t, + cutlass::gemm::collective::StageCountAutoCarveout(sizeof(typename CollectiveEpilogue::SharedStorage))>, + StageCountType>, + MainloopScheduleType + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + CollectiveEpilogue, + TileSchedulerType + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutTagA = cutlass::gemm::detail::StrideToLayoutTagA_t; + using LayoutTagB = cutlass::gemm::detail::StrideToLayoutTagB_t; + using LayoutTagC = cutlass::gemm::detail::StrideToLayoutTagC_t; + using LayoutTagD = cutlass::gemm::detail::StrideToLayoutTagC_t; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, float alpha, float beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A.get(), Gemm::LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get(), Gemm::LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), Gemm::LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), Gemm::LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + ElementScalar(alpha), + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + ElementScalar(beta), + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Reference kernel failed. 
Last CUDA error: " + << cudaGetErrorString(result) << std::endl; + return false; + } + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + block_A.reset(M * K * L); + block_B.reset(K * N * L); + block_C.reset(M * N * L); + block_D.reset(M * N * L); + block_ref_D.reset(M * N * L); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + } + + bool run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{}, // epilogue.thread + block_C.get(), stride_C, block_D.get(), stride_D}, + hw_info + }; + + // Custom EVT fusions will have nested unnamed args, the structure of which + // can be deduced from the type definition of the EVT. + // Each node's arguments has the recursive structure of + // {first_child_args, ..., last_child_args, op_args}, + // For more complex examples of EVT initialization please refer to + // include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp + if constexpr (UseCustomEVT) { + arguments.epilogue.thread = + { // ternary op : beta * C + (alpha * acc) + {{options.beta}}, // leaf op+args : beta + {}, // leaf op+args : C + { // binary op : alpha * acc + {{options.alpha}}, // leaf op+args : alpha + {}, // leaf op+args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }; // end ternary op + } + // Pre-defined fusions will have flat, named args for user-friendlyness + else { + arguments.epilogue.thread.alpha = options.alpha; + arguments.epilogue.thread.beta = options.beta; + } + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + std::cerr << "This kernel is not supported. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return false; + } + + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return false; + } + + // Run the GEMM + status = gemm_op.run(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to launch the CUTLASS kernel. 
Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return false; + } + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Error running the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(result) << std::endl; + return false; + } + + // Verify that the result is correct + bool passed = verify(problem_size, options.alpha, options.beta); + if (!passed) { + std::cerr << "Reference check failed" << std::endl; + } + + return passed; + } + +}; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to print a description of the example run and its result +void print_result(const std::string& description, bool passed) { + std::cout << description << ": " << (passed ? "Passed" : "Failed") << std::endl; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (__CUDACC_VER_MAJOR__ < 12 || props.major < 9) { + std::cout + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n"; + return 0; + } + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + bool passed; + + // This first example constructs a GEMM using the default schedule and stage count provided by + // the CollectiveBuilder. The scheduling policy that is expected to be most performant will be + // selected and the maximum number of stages that can fit in shared memory will be selected. + // + // This example is equivalent to declaring + // ExampleRunner< + // cutlass::gemm::collective::KernelScheduleAuto, + // cutlass::epilogue::collective::EpilogueScheduleAuto, + // cutlass::gemm::collective::StageCountAuto> + // Each of the `Auto` types indicate that the CollectiveBuilder should determine the scheduling policy and + // stage count. Note that the behavior of the CollectiveBuilder with `Auto` parameters is subject to change + // -- do not rely on `Auto` if you require a specific scheduling policy. + // If you opt in to a non-'Auto' schedule, make sure all collectives are built using specific, compatible schedules. 
+ ExampleRunner<> auto_schedule_auto_stage_runner; + passed = auto_schedule_auto_stage_runner.run(options, hw_info); + print_result("Automatically-selected schedule and stage count", passed); + + // One can override the stage count used in the GEMM by replacing cutlass::gemm::collective::StageCountAuto + // with the number of stages to use (5 in this case). + ExampleRunner< + cutlass::gemm::collective::KernelScheduleAuto, + cutlass::epilogue::collective::EpilogueScheduleAuto, + _5> auto_schedule_5_stage_runner; + + passed = auto_schedule_5_stage_runner.run(options, hw_info); + print_result("Automatically-selected schedule with 5 stages", passed); + + // One can also override the scheduling policy to use. In this case, use the KernelTma scheduling + // policy, which specifies that the Hopper TMA feature should be used, and we also use an epilogue + // that does not use any shared memory. + ExampleRunner tma_schedule_auto_stage_runner; + passed = tma_schedule_auto_stage_runner.run(options, hw_info); + print_result("TMA schedule with automatically-selected stage count", passed); + + // Here, we override the scheduling policy to use Hopper's TMA feature alongside the warp-specialized + // scheduling policy, and an epilogue that does not use any shared memory. + ExampleRunner ws_schedule_auto_stage_runner; + passed = ws_schedule_auto_stage_runner.run(options, hw_info); + print_result("Warp-specialized TMA schedule with automatically-selected stage count", passed); + + // Here, we override the scheduling policy to use Hopper's TMA feature, alongside the warp-specialized + // scheduling policy, TMA-based epilogue, leveraging persistent thread blocks. + ExampleRunner< + cutlass::gemm::KernelTmaWarpSpecializedPingpong, + cutlass::epilogue::TmaWarpSpecialized> ws_pingpong_schedule_auto_stage_runner; + passed = ws_pingpong_schedule_auto_stage_runner.run(options, hw_info); + print_result("Ping-pong warp-specialized TMA schedule with automatically-selected stage count", passed); + + // Here, we override the scheduling policy to use stream-K problem decomposition atop the cooperative + // warp-specialized scheduling policy. This kernel continues to leverage persistent thread blocks + // as well aso TMA in both the mainloop and epilogue. 
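+  // Stream-K additionally splits the reduction (K) dimension of output tiles across SMs, which
+  // helps balance work when the number of output tiles does not divide evenly among the
+  // available SMs.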
+ ExampleRunner< + cutlass::gemm::KernelTmaWarpSpecializedCooperative, + cutlass::epilogue::TmaWarpSpecializedCooperative, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::StreamKScheduler> ws_cooperative_stream_k_schedule_auto_stage_runner; + passed = ws_cooperative_stream_k_schedule_auto_stage_runner.run(options, hw_info); + print_result("Cooperative warp-specialized TMA schedule using stream-K with automatically-selected stage count", passed); + + // Here, we override the fusion operation to use a customized EVT fusion, in addition to the previous schedule overrides + ExampleRunner< + cutlass::gemm::KernelTmaWarpSpecializedCooperative, + cutlass::epilogue::TmaWarpSpecializedCooperative, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::PersistentScheduler, + true> ws_cooperative_schedule_auto_stage_custom_evt_runner; + passed = ws_cooperative_schedule_auto_stage_custom_evt_runner.run(options, hw_info); + print_result("Cooperative warp-specialized TMA schedule using custom epilogue visitor tree with automatically-selected stage count", passed); + +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/49_hopper_gemm_with_collective_builder/CMakeLists.txt b/examples/49_hopper_gemm_with_collective_builder/CMakeLists.txt new file mode 100644 index 0000000000..4925105d75 --- /dev/null +++ b/examples/49_hopper_gemm_with_collective_builder/CMakeLists.txt @@ -0,0 +1,34 @@ + +# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# Both filenames are shorter to avoid MAX_PATH issues on Windows. 
+cutlass_example_add_executable( + 49_collective_builder + 49_collective_builder.cu + ) diff --git a/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu b/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu new file mode 100644 index 0000000000..a736e5ce31 --- /dev/null +++ b/examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle.cu @@ -0,0 +1,526 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Hopper GEMM example to create a GEMM kernel with custom Collectives + + The following example shows how to assemble a custom GEMM kernel that spells out the Collectives + directly instead of using a builder and, in the process, instance a more efficient Epilogue + (from `cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp`) instead of using the default epilogue. + + The GemmUniversal API takes 3 main template arguments: + (1) the problem shape / extents + (2) the collective mainloop type + (3) the collective epilogue type + + While the collecive mainloop can be stamped out using a CollectiveBuilder interface, it is + possible to build a custom collective mainloop directly as well. Furthermore, since epilogues + do not yet have a builder interface, this example shows how to instantiate a more-efficient + epilogue alongside the collective mainloop. + + Note: there are several ways to implement the GEMM epilogue in Hopper - each with its own set + of trade-offs. So it is recommended that users look at the options available under + cutlass/epilogue/collective and evaluate for their particular scenario. 
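+
+    For orientation, the final assembly at the bottom of this example has the following overall
+    shape (the mainloop and epilogue types referenced here are built up step by step in main()):
+
+      using GemmKernel = cutlass::gemm::kernel::GemmUniversal<
+          cute::Shape<int,int,int,int>,   // (1) problem shape: M, N, K, L (batch)
+          CollectiveMainloop,             // (2) hand-assembled CollectiveMma
+          Epilogue>;                      // (3) custom shared-memory-swizzled epilogue
+      using Gemm = cutlass::gemm::device::GemmUniversalAdapter<GemmKernel>;
+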
+ + Please refer to examples 48, 49 to learn more about kernel schedules and other CuTe examples + present in `test/unit/cute` to famialiarize with the basics of CuTe. + + Examples: + + $ ./examples/50_hopper_gemm_with_epilogue_swizzle/50_hopper_gemm_with_epilogue_swizzle +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/util/command_line.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +using namespace cute; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + bool error; + + int m, n, k, l; + int alpha, beta; + + Options(): + help(false), + error(false), + m(2048), n(2048), k(2048), l(1), + alpha(1), beta(0) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m, 2048); + cmd.get_cmd_line_argument("n", n, 2048); + cmd.get_cmd_line_argument("k", k, 2048); + cmd.get_cmd_line_argument("l", l, 1); + cmd.get_cmd_line_argument("alpha", alpha, 1); + cmd.get_cmd_line_argument("beta", beta, 0); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "50_hopper_gemm_with_epilogue_swizzle\n\n" + << "Hopper GEMM Example with Epilogue Swizzle.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent (batch count) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + + return true; +} + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +// Wrapper to run and verify a GEMM. 
+template < + class Gemm +> +struct ExampleRunner { + + using StrideA = typename Gemm::GemmKernel::StrideA; + using StrideB = typename Gemm::GemmKernel::StrideB; + using StrideC = typename Gemm::GemmKernel::StrideC; + using StrideD = typename Gemm::GemmKernel::StrideD; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + using LayoutD = typename Gemm::LayoutD; + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementAcc = typename Gemm::ElementAccumulator; + + using CollectiveEpilogue = typename Gemm::CollectiveEpilogue; + using ElementC = typename Gemm::ElementC; + using ElementOutput = typename CollectiveEpilogue::ElementOutput; + using ElementCompute = typename CollectiveEpilogue::ElementCompute; + using ElementAccumulator = typename CollectiveEpilogue::ElementAccumulator; + + using ProblemShapeType = typename Gemm::GemmKernel::ProblemShape; + + // + // Data members + // + + /// Initialization + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + uint64_t seed = 0; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation block_ref_D; + + // + // Methods + // + + bool verify(const ProblemShapeType& problem_size, int32_t alpha, int32_t beta) { + auto [M, N, K, L] = problem_size; + + cutlass::TensorRef ref_A(block_A.get(), LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get(), LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get(), LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get(), LayoutD::packed({M, N})); + + cutlass::reference::device::GemmComplex( + {M, N, K}, + ElementCompute(alpha), + ref_A, + cutlass::ComplexTransform::kNone, + ref_B, + cutlass::ComplexTransform::kNone, + ElementCompute(beta), + ref_C, + ref_D, + ElementAccumulator(0), + L, // batch_count + M * K, // batch_stride_A + K * N, // batch_stride_B + M * N, // batch_stride_C + M * N // batch_stride_D + ); + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Reference kernel failed. 
Last CUDA error: " + << cudaGetErrorString(result) << std::endl; + return false; + } + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual(block_ref_D.get(), block_D.get(), block_D.size()); + + return passed; + } + + /// Initialize operands to be used in the GEMM and reference GEMM + void initialize(const ProblemShapeType& problem_size) { + auto problem_shape_MNKL = cute::append<4>(problem_size, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + block_A.reset(M * K * L); + block_B.reset(K * N * L); + block_C.reset(M * N * L); + block_D.reset(M * N * L); + block_ref_D.reset(M * N * L); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + } + + bool run(const Options& options, const cutlass::KernelHardwareInfo& hw_info) { + ProblemShapeType problem_size = ProblemShapeType{options.m, options.n, options.k, options.l}; + + initialize(problem_size); + + typename Gemm::GemmKernel::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_size, + {block_A.get(), stride_A, block_B.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D}, + hw_info + }; + + Gemm gemm_op; + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + std::cerr << "This kernel is not supported. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return false; + } + + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return false; + } + + // Run the GEMM + status = gemm_op.run(); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Failed to launch the CUTLASS kernel. Last CUDA error is: " + << cudaGetErrorString(cudaGetLastError()) << std::endl; + return false; + } + + cudaError_t result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "Error running the CUTLASS kernel. 
Last CUDA error is: " + << cudaGetErrorString(result) << std::endl; + return false; + } + + // Verify that the result is correct + bool passed = verify(problem_size, options.alpha, options.beta); + if (!passed) { + std::cerr << "Reference check failed" << std::endl; + } + + return passed; + } + +}; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (__CUDACC_VER_MAJOR__ < 12 || props.major < 9) { + std::cout + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater) and CUDA 12.0 or greater.\n"; + return 0; + } + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + + // + // Run examples + // + + // The KernelHardwareInfo struct holds the number of SMs on the GPU with a given device ID. This + // information is used by the underlying kernel. + cutlass::KernelHardwareInfo hw_info; + + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + bool passed; + + // Problem configuration + using ElementA = int8_t; + using ElementB = int8_t; + using ElementAcc = int32_t; + using ElementOutput = int8_t; + + // Note : Only TN WGMMA Gemm is supported currently in 3.0 + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::ColumnMajor; + using LayoutD = cutlass::layout::ColumnMajor; + + // Tiling configuration selection + using TileShape = Shape<_128,_64,_128>; + + // Choosing a thread block cluster larger than 1 allows us to Multicast data across thread blocks + using ClusterShape = Shape<_1,_2,_1>; + + // + // Assembling the CollectiveMainloop type + // + + // Pipeline Depth to be used i.e number of A, B buffers in shared memory + constexpr int PipelineStages = 8; + + // Let's choose a Warp-Specialized Mainloop implemention which uses TMA + // Note : This requires / assumes the tensors to be 16B aligned + using DispatchPolicy = cutlass::gemm::MainloopSm90TmaGmmaWarpSpecialized; + + // TN => K Major for both A & B + static constexpr cute::GMMA::Major GmmaMajorA = cute::GMMA::Major::K; + static constexpr cute::GMMA::Major GmmaMajorB = cute::GMMA::Major::K; + + // We use the SS op selector as both A, B operands are read directly from SMEM (for TN WGMMA) + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< + ElementA, ElementB, ElementAcc, TileShape, GmmaMajorA, GmmaMajorB>())); + + // A loads can be optimized with multicast if cluster-n > 1 + using GmemTiledCopyA = std::conditional< cute::size(shape<1>(ClusterShape{})) == 1, + cute::SM90_TMA_LOAD, + cute::SM90_TMA_LOAD_MULTICAST>::type; + + // B loads can be optimized with multicast if cluster-m > 1 + using GmemTiledCopyB = std::conditional< 
cute::size(shape<0>(ClusterShape{})) == 1, + cute::SM90_TMA_LOAD, + cute::SM90_TMA_LOAD_MULTICAST>::type; + + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape{})), decltype(cute::get<2>(TileShape{})) + >()); + + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape{})), decltype(cute::get<2>(TileShape{})) + >()); + + using CollectiveMainloop = cutlass::gemm::collective::CollectiveMma< + DispatchPolicy, + TileShape, + ElementA, + cutlass::gemm::TagToStrideA_t, + ElementB, + cutlass::gemm::TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + void, // Does not need a SmemCopyAtom, since A is read directly from SMEM + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + void, // Does not need a SmemCopyAtom, since B is read directly from SMEM + cute::identity + >; + + // + // Assembling the Collective Epilogue Type + // + + // Break the 128 along TILE_M into chunks of 32, to get a 128B leading dimension + using PreSwizzleLayout = Layout< Shape< Shape <_32,_4 >,_64>, + Stride,_32>>; + + // 128 threads loading 16 elements each (to get vectorized global stores) + using TileShapeS2R = Shape<_128,_16>; + + // Layout to ensure bank-conflict free loads & stores + using SmemLayout = ComposedLayout< + Swizzle<3,4,3>, + smem_ptr_flag_bits::value>, + PreSwizzleLayout>; + + // Tiled copy from Smem to Registers + // Note : CuTe will vectorize this copy if the tiling + swizzling above were right + using TiledCopyS2R = TiledCopy< + Copy_Atom, + Layout< Shape<_128,_16>, + Stride<_16,_1>>, + TileShapeS2R>; + + using Epilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::Epilogue< + cutlass::gemm::TagToStrideC_t, + cutlass::gemm::TagToStrideC_t, + cutlass::epilogue::thread::LinearCombination, + SmemLayout, + Copy_Atom, + TiledCopyS2R, + Copy_Atom>>; + + // + // Assembling the GemmKernel + // + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, + CollectiveMainloop, + Epilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + ExampleRunner runner; + + passed = runner.run(options, hw_info); + + std::cout << "WGMMA GEMM with Epilogue Swizzle : " << (passed ? "Passed" : "Failed") << std::endl; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/50_hopper_gemm_with_epilogue_swizzle/CMakeLists.txt b/examples/50_hopper_gemm_with_epilogue_swizzle/CMakeLists.txt new file mode 100644 index 0000000000..5498d4effe --- /dev/null +++ b/examples/50_hopper_gemm_with_epilogue_swizzle/CMakeLists.txt @@ -0,0 +1,35 @@ + +# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. 
Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + +cutlass_example_add_executable( + 50_hopper_gemm_with_epilogue_swizzle + 50_hopper_gemm_with_epilogue_swizzle.cu + ) diff --git a/examples/51_hopper_gett/51_hopper_gett.cu b/examples/51_hopper_gett/51_hopper_gett.cu new file mode 100644 index 0000000000..005eaec5a1 --- /dev/null +++ b/examples/51_hopper_gett/51_hopper_gett.cu @@ -0,0 +1,371 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Example of a GETT targeting Hopper tensor cores using the CUTLASS 3.x API. + + CUTLASS has long provided implementations of Generalized Matrix times Matrix (GEMM) kernels. + However, a plethora of workloads compute on higher ranked tensors. 
Products of such tensors, + called tensor contractions, can be executed as multiple batched GEMMs, however, they can be + further accelerated with kernels that natively operate on these higher ranked tensors to + perform Generalized Tensor times Tensor contractions (GETT). CuTe's hierarchical layouts + and CUTLASS 3.0's unified micro-kernels make implementation of GETTs trivial. In this example, + we show how CUTLASS 3.0, CuTe, and Hopper's TMA feature together can accelerate GETTs while + making the process of authoring custom GETT kernels easier than ever before. + + The modes of a tensor that participate in a GETT can be fundamentally grouped into four + semantic categories. The contraction modes (or K-modes) only appear in the A and B (left and right) + inputs but not in the C output tensor. Row modes (or M-modes) only appear in the left + input tensor (A) and the output tensor (C). Column modes (or N-modes) only appear in the + right (B) input tensor and the output tensor (C). Batch modes (or L-modes) appear in all + input and output tensors. If we fold the many modes of a tensor contraction into these four + categories, it would allow us to represent the input and output tensors as rank-3 "matrices" + that can be computed upon as if we were computing a batched GEMM! + + This is exactly what CuTe's hierarchical layout representation allows us to do! Instead of having + simple integers as strides for these four modes, we can have nested strides for each of these + semantic categories that themselves have multiple modes within them -- multi-mode strides! + In CUTLASS 3.0, all one has to do to take advantage of this capability is to substitute the + required multi-mode strides instead of the default ones provided by gemm::detail::TagToStrideX. + + In the following example, we illustrate how every Hopper GEMM in CUTLASS 3.0 is a GETT in disguise. + We begin by defining the four modes detailed above as Row, Col (column), Red (reduction), and + Bat (batch) strides, which we then nest for each of the in/out tensors to create our rank-3 stride + tuples. Note that although we do not define the problem shape type explicitely, it too remains a + rank-4 shape tuple just like any other batched GEMM, but instead with multi-mode shapes for each + of the four corresponding multi-modes within it. After this, the same CollectiveMma and + CollectiveBuilder we describe in examples 50 and 49 are used to create our kernel type. Nothing + else changes from a user's point of view. Note that multi-mode strides do not affect our + specializations in any way -- the lexical spelling of our kernels remains the same. The + only difference between a CUTLASS 3 batched GEMM and GETT are the instaced CuTe Layouts. + + CollectiveBuilders rely on detecting the static-1 in the stride tuples to determine the major mode, + which is what the example demonstrates. However, it is possible to have all modes be dynamic as well + if the user assembles a CollectiveMma manually and ensures that the runtime strides are compatible + with the static micro-kernel of the collective (TiledMma, TiledCopy, and smem layouts). On the other + hand, a user can have more than one static stride too (which need not correspond to the major mode). + + In particular, this example demonstrates a GETT where the 0th M-mode (M0) in A and the 0th K-mode (K0) + in B are major. All other combinations of major modes are supported, with the exception of mixed + K-major scenarios where both A and B are K-major (e.g. K0 is major in A but K1 is major in B). 
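    To make the folding concrete, consider the third example execution listed at the end of
    this comment: C's modes are (m,a,b,p,q,n,l), A's are (m,l,b,k,a), and B's are (k,n,p,q,l).
    The modes shared by A and C but absent from B -- m, a, b -- fold into the row multi-mode;
    n, p, and q fold into the column multi-mode; k is the contraction multi-mode; and l is the
    batch multi-mode. A matching multi-mode stride for A could then be sketched as follows
    (a sketch only; the exact, more general aliases this example actually uses are defined in
    main() below):

      using RowModeStridesA = cute::Stride<cute::Int<1>, int64_t, int64_t>;  // (m,a,b), m major
      using RedModeStridesA = cute::Stride<int64_t>;                         // (k)
      using BatModeStridesA = cute::Stride<int64_t>;                         // (l)
      using StrideA         = cute::Stride<RowModeStridesA, RedModeStridesA, BatModeStridesA>;

    From the kernel's perspective this is still just a rank-3 "matrix" stride (row, reduction,
    batch); only the nesting inside each entry changes.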
+ NVIDIA Hopper architecture's TMA feature makes the predictaion required to implement these complicated + kernels trivial, as it is all handled by TMA itself without requiring any programmer effort. + + Example executions, where the stride order defines the major-order (major on the left): + 51_hopper_gett --modeC=m,n,l --modeA=m,k,l --modeB=k,n,l --extents=m:4096,n:4096,k:4096 + 51_hopper_gett --modeC=l,m,n --modeA=m,l,k --modeB=k,n,l --extents=m:128,n:128,k:128,l:64 + 51_hopper_gett --modeC=m,a,b,p,q,n,l --modeA=m,l,b,k,a --modeB=k,n,p,q,l --extents=m:32,a:32,b:3,n:128,k:128,l:4,p:3,q:3 +*/ + +#include "gett_kernel.cuh" +#include "thrust/host_vector.h" +#include "thrust/device_vector.h" + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" + +#include "cutlass/util/gett_commandline.hpp" +#include "cutlass/util/reference/device/gett.hpp" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/print_error.hpp" + +namespace example { + +// Returns true if the left-most value in the tuple is statically known to be 1 +template +constexpr bool +is_left_major() { + // Account for stride types with and without batch mode and batch modes with static zero stride + return cute::is_constant<1, decltype(cute::size<0,0>(Stride{}))>::value; +} + +// Same as cute::make_int_tuple but inserts a major stride (Int<1>) for the leftmost mode if required +template +static constexpr +auto +make_stride_tuple(Indexable const& t, int n, int64_t init_default = 0) { + static_assert(Rank > 1); + if constexpr (IsMajor) { + return cute::transform(cute::make_seq{}, [&](auto i) { + if constexpr (i == 0) { + return cute::Int<1>{}; + } + else { + return i < n ? t[i] : init_default; + } + }); + } + else { + return cute::make_int_tuple(t, n, init_default); + } +} + +} // namespace example + +////////////////////////////////////////////////////////////////////////////// + +int +main(int argc, char const* argv[]) { +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + using namespace cute; + + if (argc != 5) { + std::cout << "Number of command line args must be 4.\n"; + cutlass::GettCommandLine::print_usage(); + return 0; + } + + // + // Define the stride types for A, B, C, and D + // + + // Stride for A (left input). If reduction mode is major, same must be major in B + // For this example, M0 is major in A. + using RowModeStridesA = cute::Stride, int64_t, int64_t, int64_t>; + using RedModeStridesA = cute::Stride; + using BatModeStridesA = cute::Stride; + + // Stride for B (right input). If reduction mode is major, same must be major in A + // For this example, K0 is major in B. + using ColModeStridesB = cute::Stride; + using RedModeStridesB = cute::Stride, int64_t, int64_t>; + using BatModeStridesB = cute::Stride; + + // Strides for output, which can all be dynamic. + using RowModeStridesC = cute::Stride; + using ColModeStridesC = cute::Stride; + using BatModeStridesC = cute::Stride; + + // Assmble our rank-3 multi-mode strides for the in/out tensors + using StrideA = cute::Stride; + using StrideB = cute::Stride; + using StrideC = cute::Stride; + + // Note: C and D share strides here for simplicity. + // In general, they need not have the same layout. 
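  // As an illustrative sketch (the names here are hypothetical and not used elsewhere), a
  // distinct layout for D would simply be another nested stride assembled from its own
  // mode-stride tuples, e.g.:
  //
  //   using RowModeStridesD = cute::Stride<int64_t, int64_t, int64_t, int64_t>;
  //   using StrideD_custom  = cute::Stride<RowModeStridesD, ColModeStridesC, BatModeStridesC>;
  //
  // This example keeps D identical to C: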
+ using StrideD = StrideC; + + // + // Define element types for tensors and intermediate values + // + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = cutlass::half_t; + using ElementD = float; + using ElementAccumulator = float; + using ElementEpilogue = float; + + // The following constexpr values set the max number of modes in each MNKL mode + constexpr int MaxRank_M = cute::rank(RowModeStridesA{}); // Max row modes + constexpr int MaxRank_N = cute::rank(ColModeStridesB{}); // Max column modes + constexpr int MaxRank_K = cute::rank(RedModeStridesA{}); // Max contraction modes + constexpr int MaxRank_L = cute::rank(BatModeStridesA{}); // Max batch modes + static_assert(cute::rank(RowModeStridesA{}) == cute::rank(RowModeStridesC{})); + static_assert(cute::rank(ColModeStridesB{}) == cute::rank(RowModeStridesC{})); + static_assert(cute::rank(RedModeStridesA{}) == cute::rank(RedModeStridesB{})); + static_assert(cute::rank(BatModeStridesA{}) == cute::rank(BatModeStridesC{})); + static_assert(cute::rank(BatModeStridesB{}) == cute::rank(BatModeStridesC{})); + + // Parse command line to get modes, extents, and strides + cutlass::GettCommandLine cmd; + auto parsed_args = cmd.parse(argc, argv, true); + + auto& m = parsed_args.M; + auto& ldAm = parsed_args.ldAm; + auto& ldCm = parsed_args.ldCm; + int rank_m = int(m.size()); + + auto& n = parsed_args.N; + auto& ldBn = parsed_args.ldBn; + auto& ldCn = parsed_args.ldCn; + int rank_n = int(n.size()); + + auto& k = parsed_args.K; + auto& ldAk = parsed_args.ldAk; + auto& ldBk = parsed_args.ldBk; + int rank_k = int(k.size()); + + auto& l = parsed_args.L; + auto& ldAl = parsed_args.ldAl; + auto& ldBl = parsed_args.ldBl; + auto& ldCl = parsed_args.ldCl; + int rank_l = int(l.size()); + + if ((rank_m > MaxRank_M) || (rank_n > MaxRank_N) || (rank_k > MaxRank_K) || (rank_l > MaxRank_L)) { + std::cerr << "ERROR: Input has more modes than statically configured."; + return 1; + } + + // Check that the user input major stride match the static major strides. 
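  // (example::is_left_major<Stride>() is the helper defined above: it is true only when the
  //  left-most entry of the given mode-stride tuple is statically Int<1>. For modes that were
  //  not declared major the condition below is therefore a compile-time false and the branch
  //  is effectively dead code; for the declared-major modes -- M0 in A and K0 in B in this
  //  example -- it verifies at runtime that the user-supplied leading stride is indeed 1.)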
+ if (example::is_left_major() && (ldAm[0] != 1)) { + std::cerr << "ERROR: A_M0 is expected to be major, but was not in the provided input!\n"; + return 1; + } + + if (example::is_left_major() && (ldAk[0] != 1)) { + std::cerr << "ERROR: A_K0 is expected to be major, but was not in the provided input!\n"; + return 1; + } + + if (example::is_left_major() && (ldBn[0] != 1)) { + std::cerr << "ERROR: B_N0 is expected to be major, but was not in the provided input!\n"; + return 1; + } + + if (example::is_left_major() && (ldBk[0] != 1)) { + std::cerr << "ERROR: B_K0 is expected to be major, but was not in the provided input!\n"; + return 1; + } + + // Convert to `cute::Tuple`s and set up arguments + auto M = make_int_tuple(m.data(), rank_m, 1); + auto dAm = example::make_stride_tuple()>(ldAm.data(), rank_m); + auto dCm = example::make_stride_tuple()>(ldCm.data(), rank_m); + + auto N = make_int_tuple(n.data(), rank_n, 1); + auto dBn = example::make_stride_tuple()>(ldBn.data(), rank_n); + auto dCn = example::make_stride_tuple()>(ldCn.data(), rank_n); + + auto K = make_int_tuple(k.data(), rank_k, 1); + auto dAk = example::make_stride_tuple()>(ldAk.data(), rank_k); + auto dBk = example::make_stride_tuple()>(ldBk.data(), rank_k); + + auto L = make_int_tuple(l.data(), rank_l, 1); + auto dAl = make_int_tuple(ldAl.data(), rank_l, 0); + auto dBl = make_int_tuple(ldBl.data(), rank_l, 0); + auto dCl = make_int_tuple(ldCl.data(), rank_l, 0); + + // Concat tuples to turn it into rank-4 problem shape and rank-3 strides, just like GEMM + auto problem_shape = make_shape(M, N, K, L); + StrideA stride_A = make_stride(dAm, dAk, dAl); + StrideB stride_B = make_stride(dBn, dBk, dBl); + StrideC stride_C = make_stride(dCm, dCn, dCl); + StrideD stride_D = stride_C; + + auto alpha = ElementEpilogue(1.0f); + auto beta = ElementEpilogue(1.0f); + + // + // Allocate and init tensors + // + auto M_size = std::accumulate(std::begin(m), std::end(m), 1, std::multiplies<>{}); + auto N_size = std::accumulate(std::begin(n), std::end(n), 1, std::multiplies<>{}); + auto K_size = std::accumulate(std::begin(k), std::end(k), 1, std::multiplies<>{}); + auto L_size = std::accumulate(std::begin(l), std::end(l), 1, std::multiplies<>{}); + + thrust::host_vector h_A(M_size * K_size * L_size); + thrust::host_vector h_B(N_size * K_size * L_size); + thrust::host_vector h_C(M_size * N_size * L_size); + thrust::host_vector h_D(M_size * N_size * L_size); + + // Note: the cast to int here is to avoid false-negative ref-checks which can + // occur due to floating point arithmetic not being purely associative. 
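  // (The expressions below yield small integer values in [-1, 3]. Products and partial sums
  //  of such integers remain exactly representable in the half/float types used here, so the
  //  CUTLASS kernel and the device reference should produce identical results regardless of
  //  accumulation order, allowing the exact BlockCompareEqual check further down.)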
+ for (auto& a : h_A) a = ElementA(int(4*(rand() / double(RAND_MAX)) - 1)); + for (auto& b : h_B) b = ElementB(int(4*(rand() / double(RAND_MAX)) - 1)); + for (auto& c : h_C) c = ElementC(int(4*(rand() / double(RAND_MAX)) - 1)); + for (auto& d : h_D) d = ElementD(-1); + + thrust::device_vector d_A = h_A; + thrust::device_vector d_B = h_B; + thrust::device_vector d_C = h_C; + thrust::device_vector cutlass_result = h_D; + thrust::device_vector reference_result = h_D; + + // + // Compute GETT + // + auto status = example::gett_kernel( + problem_shape, + d_A.data().get(), stride_A, + d_B.data().get(), stride_B, + ElementAccumulator{}, + d_C.data().get(), stride_C, + cutlass_result.data().get(), stride_D, + alpha, beta); + + if (cutlass::Status::kSuccess != status) { + std::cerr << "ERROR: GETT operator launch failed.\n"; + return 1; + } + + auto cuda_err = cudaDeviceSynchronize(); + if (cudaSuccess != cuda_err) { + std::cerr << "ERROR: GETT operator execution failed. with error :"; + std::cerr << cudaGetErrorString(cuda_err) << "\n"; + return 1; + } + + // + // Verify + // + + cutlass::reference::device::gett( + problem_shape, + d_A.data().get(), stride_A, + d_B.data().get(), stride_B, + ElementAccumulator{}, + d_C.data().get(), stride_C, + reference_result.data().get(), stride_D, + alpha, beta); + + cuda_err = cudaDeviceSynchronize(); + if (cudaSuccess != cuda_err) { + std::cerr << "ERROR: GETT reference execution failed. with error :"; + std::cerr << cudaGetErrorString(cuda_err) << "\n"; + return 1; + } + + // Check if output from CUTLASS kernel and reference kernel are equal or not + bool passed = cutlass::reference::device::BlockCompareEqual( + reference_result.data().get(), cutlass_result.data().get(), cutlass_result.size()); + if (passed) { + std::cout << "GETT verification passed.\n"; + return 0; + } + else { + std::cerr << "ERROR: GETT verification failed! Printing detailed stats.\n"; + h_D = reference_result; + thrust::host_vector h_cutlass_result = cutlass_result; + print_relative_error(h_cutlass_result.size(), h_cutlass_result.data(), h_D.data()); + + std::cout << "StrideA: "; print(stride_A); std::cout << '\n'; + std::cout << "StrideB: "; print(stride_B); std::cout << '\n'; + std::cout << "StrideC: "; print(stride_C); std::cout << '\n'; + std::cout << "StrideD: "; print(stride_D); std::cout << '\n'; + return 1; + } +#else + std::cerr << "Unsupported example. Please ensure CUTLASS_ARCH_MMA_SM90_SUPPORTED is defined.\n"; + return 0; +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +} diff --git a/examples/51_hopper_gett/CMakeLists.txt b/examples/51_hopper_gett/CMakeLists.txt new file mode 100644 index 0000000000..f18dff3817 --- /dev/null +++ b/examples/51_hopper_gett/CMakeLists.txt @@ -0,0 +1,32 @@ +# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. 
Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_example_add_executable( + 51_hopper_gett + 51_hopper_gett.cu +) diff --git a/examples/51_hopper_gett/gett_kernel.cuh b/examples/51_hopper_gett/gett_kernel.cuh new file mode 100644 index 0000000000..6a775d137b --- /dev/null +++ b/examples/51_hopper_gett/gett_kernel.cuh @@ -0,0 +1,138 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/tensor.hpp" + +#include "cutlass/arch/arch.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" + +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +namespace example { + +// +// GETT entry point +// +template < + class ProblemShapeMNKL, + class ElementA, + class StrideA, + class ElementB, + class StrideB, + class ElementAccumulator, + class ElementC, + class StrideC, + class ElementD, + class StrideD, + class ElementEpilogue> +cutlass::Status +gett_kernel( + ProblemShapeMNKL problem_shape_mnkl, + ElementA const* ptr_A, StrideA stride_a_mkl, + ElementB const* ptr_B, StrideB stride_b_nkl, + ElementAccumulator _, + ElementC const* ptr_C, StrideC stride_c_mnl, + ElementD * ptr_D, StrideD stride_d_mnl, + ElementEpilogue alpha, ElementEpilogue beta, + cudaStream_t stream = 0) { + using namespace cute; + + // TileShape -- GETT configuration + // Specify the number of elements to take from each mode + // BLK_M = (M0,M1,...) BLK_N = (M0,M1,...) BLK_K = (K0,K1,...) + + // Take 128 from m0, 128 from n0, 64 from k0 + using TileShape = Shape, Shape<_128>, Shape<_64>>; + + /* Other examples: + * Take 32 elements from m0 and 4 elements from m1 + * Take 64 elements from n0 and 2 elements from n1 + * Take 8 elements from k0 and 8 elements from k1 + **/ + // using TileShape = Shape, Shape<_64,_2>, Shape<_8,_8>>; + + using EpilogueThreadOp = cutlass::epilogue::thread::LinearCombination< + ElementD, 1, ElementAccumulator, ElementEpilogue, cutlass::epilogue::thread::ScaleType::Default, + cutlass::FloatRoundStyle::round_to_nearest, ElementC>; + + // No changes are required to the default epilogue + using CollectiveEpilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::DefaultEpilogue< + StrideC, + StrideD, + EpilogueThreadOp, + cutlass::gemm::EpilogueDefault>>; + + // CollectiveMma for GETTs can be built using the CollectiveBuilders + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, StrideA, 128 / cutlass::sizeof_bits::value, + ElementB, StrideB, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + TileShape, Shape<_1,_2,_1>, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + // The GETT kernel is a composition of a collective mainloop and epilogue, just like any 3.x GEMM + using GettKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShapeMNKL, + CollectiveMainloop, + CollectiveEpilogue>; + + using GettOperator = cutlass::gemm::device::GemmUniversalAdapter; + + typename GettOperator::Arguments args { + cutlass::gemm::GemmUniversalMode::kBatched, + problem_shape_mnkl, + { ptr_A, stride_a_mkl, ptr_B, stride_b_nkl }, + { {alpha, beta}, ptr_C, stride_c_mnl, ptr_D, stride_d_mnl } + }; + +#if CUTLASS_DEBUG_TRACE_LEVEL > 0 + print("Problem shape:"); + print("\tM: "); print(cute::get<0>(problem_shape_mnkl)); print("\n"); + print("\tN: "); print(cute::get<1>(problem_shape_mnkl)); print("\n"); + print("\tK: "); print(cute::get<2>(problem_shape_mnkl)); print("\n"); + print("\tL: "); 
print(cute::get<3>(problem_shape_mnkl)); print("\n"); + print("TileSape:"); print(TileShape{}); print("\n"); +#endif + + GettOperator op; + return op(args, stream); +} + +} // namespace example diff --git a/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu b/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu new file mode 100644 index 0000000000..0a74e02a83 --- /dev/null +++ b/examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion.cu @@ -0,0 +1,693 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Example of a Hopper gather+GEMM+scatter kernel fusion. + + This example fuses gather before GEMM and scatter after GEMM into the same + GEMM kernel. Gather and scatter operation is controled by an index vector + to select rows or columns from A, B, C or D matrices. + + Gather/scatter operations are always performed along a strided dimension + in order to preserve vectorized loads/stores. Thus the index vector is + applied to rows of row-major matrices and columns of column-major matrices. + + Note that the index vector must contain integers in range [0,X) where + X is one of (M,N,K), depending on selected gather dimension. The problem + shape given to the GEMM kernel must consist of matrix sizes AFTER gather + and BEFORE scatter operations are applied. 
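    As a functional (not performance-representative) sketch of these semantics, gathering the
    rows of a row-major A with an index vector idx of length M_gathered behaves like the host
    loop below, where the names are illustrative only:

      for (int i = 0; i < M_gathered; ++i)
        for (int k = 0; k < K; ++k)
          A_gathered[i * K + k] = A[idx[i] * K + k];

    and scattering D applies the same indirection on the output side, i.e. row i of the
    computed tile is written to row idx[i] of D. The fused kernel performs these indirections
    directly in its global-memory loads and stores instead of materializing intermediate
    gathered tensors, which is what the unfused reference path in this example does.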
+*/ + +#include +#include +#include +#include +#include +#include +#include +#include +#include + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/device/gemm_universal.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/tensor_view_io.h" + +#include "helper.h" +#include "gather_gemm.hpp" +#include "gather_kernel.cuh" +#include "scatter_epilogue.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +using namespace cute; + +namespace example { + +// Command line options parsing +struct Options { + + bool help = false; + + cutlass::gemm::BatchedGemmCoord problem_size = {2048, 2048, 2048, 1}; + int index_size = 1024; + int mode = 1; // N-mode gather/scatter by default + + float alpha = 1.0f; + float beta = 0.0f; + + bool reference_check = true; + int iterations = 20; + + bool valid() const { + return problem_size.m() > 0 + && problem_size.n() > 0 + && problem_size.k() > 0 + && problem_size.batch() > 0 + && 0 <= mode && mode < 3 + && index_size <= problem_size.at(mode) + && iterations > 0; + } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + cmd.get_cmd_line_argument("m", problem_size.m()); + cmd.get_cmd_line_argument("n", problem_size.n()); + cmd.get_cmd_line_argument("k", problem_size.k()); + cmd.get_cmd_line_argument("batch_size", problem_size.batch()); + cmd.get_cmd_line_argument("index_size", index_size); + + char const modes[] = {'m', 'n', 'k'}; + char mode_input = modes[mode]; + cmd.get_cmd_line_argument("mode", mode_input); + mode = int(std::distance(std::begin(modes), std::find(std::begin(modes), std::end(modes), mode_input))); + + cmd.get_cmd_line_argument("alpha", alpha); + cmd.get_cmd_line_argument("beta", beta); + + cmd.get_cmd_line_argument("check", reference_check, true); + cmd.get_cmd_line_argument("iterations", iterations); + + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << + "52_hopper_gather_scatter_fusion example\n" + "\n" + " This example uses the CUTLASS Library to fuse gather/scatter of input/output tensors with GEMM.\n" + " It validates and benchmarks the fused kernel against an unfused implementation that executes\n" + " gather+GEMM+scatter in sequence and writes intermediate (gathered) tensors to memory.\n" + " For the unfused implementation two GEMM kernels are considered: default one that uses the same\n" + " schedule and instruction set as the fused one, and an optimized one that utilizes advanced\n" + " features (such as TMA units) that cannot be used by the fused kernel due to hardware constraints." 
+ "\n" + "Options:\n" + " --help If specified, displays this usage statement.\n" + " --m= GEMM M dimension\n" + " --n= GEMM N dimension\n" + " --k= GEMM K dimension\n" + " --batch_size= GEMM batch size\n" + " --index_size= Size of N dimension gather/scatter index\n" + " --mode= Gather mode (M, N, or K)\n" + " --alpha= GEMM alpha parameter\n" + " --beta= GEMM beta parameter\n" + " --iterations= Number of profiling iterations to perform.\n" + "\n" + "Examples:\n" + "\n" + "$ ./examples/52_hopper_gather_scatter_fusion/52_hopper_gather_scatter_fusion --m=1024 --n=2048 --k=1024 --mode=n --index_size=1024\n"; + + return out; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct ExampleRunner +{ + // Useful aliases + + using ProblemShape = Shape; + + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideB = cutlass::gemm::TagToStrideB_t; + using StrideC = cutlass::gemm::TagToStrideC_t; + using StrideD = cutlass::gemm::TagToStrideC_t; + + // Alias to for the epilogue type that supports gather/scatter + using Epilogue = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::EpilogueGatherScatter< + StrideC, StrideD, + cutlass::epilogue::thread::LinearCombination< + ElementD, 1, + ElementAccumulator, ElementComputeEpilogue, + cutlass::epilogue::thread::ScaleType::Default, + cutlass::FloatRoundStyle::round_to_nearest, ElementC + >, + cutlass::gemm::EpilogueDefault, + GatherC, + ScatterD + > + >; + + // Alias to for the mainloop type + using Mainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 128 / cutlass::sizeof_bits::value, + ElementB, LayoutB, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + Shape<_128,_128,_64>, + Shape<_1,_1,_1>, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::KernelCpAsyncWarpSpecialized + >::CollectiveOp; + + using Kernel = cutlass::gemm::kernel::GemmGather< + ProblemShape, + Mainloop, + Epilogue, + void, + GatherA, + GatherB + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + + static constexpr bool DoGatherA = not cutlass::platform::is_same::value; + static constexpr bool DoGatherB = not cutlass::platform::is_same::value; + static constexpr bool DoGatherC = not cutlass::platform::is_same::value; + static constexpr bool DoScatterD = not cutlass::platform::is_same::value; + + static constexpr bool GatherAonM = DoGatherA && cutlass::platform::is_same::value; + static constexpr bool GatherAonK = DoGatherA && cutlass::platform::is_same::value; + static constexpr bool GatherBonN = DoGatherB && cutlass::platform::is_same::value; + static constexpr bool GatherBonK = DoGatherB && cutlass::platform::is_same::value; + static constexpr bool GatherConM = DoGatherC && cutlass::platform::is_same::value; + static constexpr bool GatherConN = DoGatherC && cutlass::platform::is_same::value; + static constexpr bool ScatterDonM = DoScatterD && cutlass::platform::is_same::value; + static constexpr bool ScatterDonN = DoScatterD && cutlass::platform::is_same::value; + + static constexpr bool GatherModeM = GatherAonM || GatherConM || ScatterDonM; + static constexpr bool GatherModeN = GatherBonN || GatherConN || ScatterDonN; + static constexpr bool GatherModeK = GatherAonK || GatherBonK; + + static_assert( GatherModeM && !GatherModeN && !GatherModeK || + !GatherModeM && GatherModeN && !GatherModeK || + !GatherModeM && !GatherModeN && 
GatherModeK, + "Only one gather mode (M, N or K) is supported by example runner"); + + // Construct a reference (non-gather) GEMM kernel type + + using MainloopRef = Mainloop; + + using EpilogueRef = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::DefaultEpilogue< + StrideC, StrideD, + typename Epilogue::ThreadEpilogueOp, + typename Epilogue::EpilogueSchedule + > + >; + + using KernelRef = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + MainloopRef, + EpilogueRef, + void + >; + + using GemmRef = cutlass::gemm::device::GemmUniversalAdapter; + + // Construct an optimized reference GEMM kernel type (using TMA) + + using EpilogueOpt = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + Shape<_128,_128,_64>, + Shape<_2,_2,_1>, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementComputeEpilogue, + ElementC, LayoutC, 128 / cutlass::sizeof_bits::value, + ElementD, LayoutD, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using MainloopOpt = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutA, 128 / cutlass::sizeof_bits::value, + ElementB, LayoutB, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + Shape<_128,_128,_64>, + Shape<_2,_2,_1>, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename EpilogueOpt::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using KernelOpt = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + MainloopOpt, + EpilogueOpt, + void + >; + + using GemmOpt = cutlass::gemm::device::GemmUniversalAdapter; + + // Data members + + cutlass::gemm::BatchedGemmCoord problem_size_orig; + cutlass::gemm::BatchedGemmCoord problem_size; + ProblemShape problem_shape_orig; + ProblemShape problem_shape; + cutlass::KernelHardwareInfo hw_info; + + ElementComputeEpilogue alpha; + ElementComputeEpilogue beta; + + StrideA stride_A_orig; + StrideB stride_B_orig; + StrideC stride_C_orig; + StrideD stride_D_orig; + + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + + cutlass::device_memory::allocation tensor_a; + cutlass::device_memory::allocation tensor_b; + cutlass::device_memory::allocation tensor_c; + cutlass::device_memory::allocation tensor_d; + + cutlass::device_memory::allocation gather_indices; + + cutlass::device_memory::allocation tensor_a_gathered; + cutlass::device_memory::allocation tensor_b_gathered; + cutlass::device_memory::allocation tensor_c_gathered; + cutlass::device_memory::allocation tensor_d_gathered; + cutlass::device_memory::allocation tensor_d_reference; + + cutlass::gemm::GemmUniversalMode gemm_mode; + + Gemm gemm; + typename Gemm::Arguments arguments; + cutlass::device_memory::allocation workspace; + + GemmRef gemm_ref; + typename GemmRef::Arguments arguments_ref; + cutlass::device_memory::allocation workspace_ref; + + GemmOpt gemm_opt; + typename GemmOpt::Arguments arguments_opt; + cutlass::device_memory::allocation workspace_opt; + + ExampleRunner(Options const &options, cutlass::KernelHardwareInfo const &hw_info) + : problem_size_orig(options.problem_size), + problem_size(GatherModeM ? options.index_size : problem_size_orig.m(), + GatherModeN ? options.index_size : problem_size_orig.n(), + GatherModeK ? 
options.index_size : problem_size_orig.k(), + problem_size_orig.batch()), + problem_shape_orig(problem_size_orig.m(), problem_size_orig.n(), problem_size_orig.k(), problem_size_orig.batch()), + problem_shape(problem_size.m(), problem_size.n(), problem_size.k(), problem_size.batch()), + hw_info(hw_info), + alpha(options.alpha), + beta(options.beta), + stride_A_orig(cutlass::make_cute_packed_stride( + StrideA{}, make_shape(problem_size_orig.m(), problem_size_orig.k(), problem_size_orig.batch()))), + stride_B_orig(cutlass::make_cute_packed_stride( + StrideB{}, make_shape(problem_size_orig.n(), problem_size_orig.k(), problem_size_orig.batch()))), + stride_C_orig(cutlass::make_cute_packed_stride( + StrideC{}, make_shape(problem_size_orig.m(), problem_size_orig.n(), problem_size_orig.batch()))), + stride_D_orig(cutlass::make_cute_packed_stride( + StrideD{}, make_shape(problem_size_orig.m(), problem_size_orig.n(), problem_size_orig.batch()))), + stride_A(cutlass::make_cute_packed_stride( + StrideA{}, make_shape(problem_size.m(), problem_size.k(), problem_size.batch()))), + stride_B(cutlass::make_cute_packed_stride( + StrideB{}, make_shape(problem_size.n(), problem_size.k(), problem_size.batch()))), + stride_C(cutlass::make_cute_packed_stride( + StrideC{}, make_shape(problem_size.m(), problem_size.n(), problem_size.batch()))), + stride_D(cutlass::make_cute_packed_stride( + StrideD{}, make_shape(problem_size.m(), problem_size.n(), problem_size.batch()))), + tensor_a(problem_size_orig.m() * problem_size_orig.k() * problem_size_orig.batch()), + tensor_b(problem_size_orig.k() * problem_size_orig.n() * problem_size_orig.batch()), + tensor_c(problem_size_orig.m() * problem_size_orig.n() * problem_size_orig.batch()), + tensor_d(problem_size_orig.m() * problem_size_orig.n() * problem_size_orig.batch()), + gather_indices(options.index_size), + tensor_a_gathered(problem_size.m() * problem_size.k() * problem_size_orig.batch()), + tensor_b_gathered(problem_size.k() * problem_size.n() * problem_size_orig.batch()), + tensor_c_gathered(problem_size.m() * problem_size.n() * problem_size_orig.batch()), + tensor_d_gathered(problem_size.m() * problem_size.n() * problem_size_orig.batch()), + tensor_d_reference(problem_size_orig.m() * problem_size_orig.n() * problem_size_orig.batch()), + gemm_mode(problem_size.batch() > 1 ? cutlass::gemm::GemmUniversalMode::kBatched : cutlass::gemm::GemmUniversalMode::kGemm), + gemm(), + // When constructing arguments for gather/scatter gemm, we must pass stride arguments + // made for the original (non-gathered) problem size, because they are used to access + // tensors of the original shape. However we still use the reduced (gathered) problem + // shape since it corresponds to the logical indexing in reduced size GEMM. + arguments{ + gemm_mode, + problem_shape, + { + tensor_a.get(), + stride_A_orig, + tensor_b.get(), + stride_B_orig + }, + { + { alpha, beta }, + tensor_c.get(), stride_C_orig, + tensor_d.get(), stride_D_orig, + typename Epilogue::GatherC {gather_indices.get()}, + typename Epilogue::ScatterD{gather_indices.get()} + }, + hw_info, + {}, + typename Kernel::GatherA{gather_indices.get()}, + typename Kernel::GatherB{gather_indices.get()} + }, + workspace(Gemm::get_workspace_size(arguments)), + gemm_ref(), + arguments_ref{ + gemm_mode, + problem_shape, + { + DoGatherA ? tensor_a_gathered.get() : tensor_a.get(), + stride_A, + DoGatherB ? tensor_b_gathered.get() : tensor_b.get(), + stride_B + }, + { + { alpha, beta }, + DoGatherC ? 
tensor_c_gathered.get() : tensor_c.get(), + stride_C, + DoScatterD ? tensor_d_gathered.get() : tensor_d_reference.get(), + stride_D + }, + hw_info + }, + workspace_ref(GemmRef::get_workspace_size(arguments_ref)), + gemm_opt(), + arguments_opt{ + gemm_mode, + problem_shape, + { + DoGatherA ? tensor_a_gathered.get() : tensor_a.get(), + stride_A, + DoGatherB ? tensor_b_gathered.get() : tensor_b.get(), + stride_B + }, + { + { alpha, beta }, + DoGatherC ? tensor_c_gathered.get() : tensor_c.get(), + stride_C, + DoScatterD ? tensor_d_gathered.get() : tensor_d_reference.get(), + stride_D + }, + hw_info + }, + workspace_opt(GemmOpt::get_workspace_size(arguments_opt)) + { + // Fill input and output matrices on host using CUTLASS helper functions + cutlass::reference::device::BlockFillRandomUniform(tensor_a.get(), tensor_a.size(), 1, ElementA(7), ElementA(-8), 0); + cutlass::reference::device::BlockFillRandomUniform(tensor_b.get(), tensor_b.size(), 1, ElementB(7), ElementB(-8), 0); + cutlass::reference::device::BlockFillRandomUniform(tensor_c.get(), tensor_c.size(), 1, ElementC(7), ElementC(-8), 0); + cutlass::reference::device::BlockFillSequential(tensor_d.get(), tensor_d.size(), ElementD(0), ElementD(0)); + + // <- Fill gather_indices with unique random integers in range [0,n) + int index_range = GatherModeM ? problem_size_orig.m() : (GatherModeN ? problem_size_orig.n() : problem_size_orig.k()); + std::vector indices(index_range); + std::iota(indices.begin(), indices.end(), 0); + { // std::random_shuffle was deprecated in C++14 and removed in C++17 + std::random_device make_seed; + std::mt19937 source_of_randomness(make_seed()); + std::shuffle(indices.begin(), indices.end(), source_of_randomness); + } + gather_indices.copy_from_host(indices.data()); + + auto const gemm_init = [](auto & gemm, auto const & arguments, auto & workspace) + { + cutlass::Status status = gemm.can_implement(arguments); + CUTLASS_CHECK(status); + status = gemm.initialize(arguments, workspace.get()); + CUTLASS_CHECK(status); + }; + + gemm_init(gemm, arguments, workspace ); + gemm_init(gemm_ref, arguments_ref, workspace_ref); + gemm_init(gemm_opt, arguments_opt, workspace_opt); + } + + void debug_output(std::ostream & os) + { + auto print_tensor = [](std::ostream &os, char const * name, auto const & data, auto shape, auto stride) + { + std::vector> h_data(data.size()); + data.copy_to_host(h_data.data()); + Tensor t = make_tensor(h_data.data(), shape, stride); + os << "\n" << name << ": " << std::setw(4) << t << std::endl; + }; + { + auto [M,N,K,L] = problem_shape_orig; + print_tensor(os, "A", tensor_a, make_shape(M,K,L), stride_A_orig); + print_tensor(os, "B", tensor_b, make_shape(N,K,L), stride_B_orig); + print_tensor(os, "C", tensor_c, make_shape(M,N,L), stride_C_orig); + print_tensor(os, "D", tensor_d, make_shape(M,N,L), stride_D_orig); + print_tensor(os, "D reference", tensor_d_reference, make_shape(M,N,L), stride_D_orig); + print_tensor(os, "indices", gather_indices, make_shape(gather_indices.size()), make_stride(_1{})); + } + } + + template + static void run_gemm(Gemm2 &gemm) + { + cutlass::Status status = gemm.run(); + CUTLASS_CHECK(status); + } + + template + void run_reference(Gemm2 &gemm) + { + // Convenience wrapper around calls to separate gather/scatter kernels + auto run_gather = [this](auto call, auto const & input, auto & output, auto gather_func, auto batch_size, auto stride) + { + [[maybe_unused]] auto idx = find_if(stride, [](auto x){ return not is_constant<1, decltype(x)>{}; }); + constexpr int I = 
decltype(idx)::value; + call(input.get(), + output.get(), + gather_func, + batch_size, + static_cast(input.size() / batch_size), + static_cast(output.size() / batch_size), + static_cast(get(stride)), + hw_info); + }; + + // Forward calls via lambda to avoid specifying template arguments + auto gather_call = [](auto&&... args){ gather(static_cast(args)...); }; + // MSVC doesn't count use inside a false "if constexpr" branch. + [[maybe_unused]] auto scatter_call = [](auto&&... args){ scatter(static_cast(args)...); }; + + if constexpr (DoGatherA) { + run_gather(gather_call, tensor_a, tensor_a_gathered, arguments.gather_A, problem_size.batch(), stride_A); + } + if constexpr (DoGatherB) { + run_gather(gather_call, tensor_b, tensor_b_gathered, arguments.gather_B, problem_size.batch(), stride_B); + } + if constexpr (DoGatherC) { + if (beta != ElementComputeEpilogue(0)) { + run_gather(gather_call, tensor_c, tensor_c_gathered, arguments.epilogue.gather_C, problem_size.batch(), stride_C); + } + } + + run_gemm(gemm); + + if constexpr (DoScatterD) { + run_gather(scatter_call, tensor_d_gathered, tensor_d_reference, arguments.epilogue.scatter_D, problem_size.batch(), stride_D); + } + } + + bool verify() + { + run_gemm(gemm); + run_reference(gemm_ref); + cudaDeviceSynchronize(); + return cutlass::reference::device::BlockCompareEqual(tensor_d.get(), tensor_d_reference.get(), tensor_d.size()); + } + + bool run(Options const &options) + { + if (options.reference_check) { + if (!verify()) { + std::cout << "Failed validation" << std::endl; +#if 0 + debug_output(std::cout); +#endif + return false; + } + else { + std::cout << "Passed validation" << std::endl; + } + } + + // + // Run profiling loop + // + + auto const benchmark = [&](auto name, auto func) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + func(); + } + timer.stop(); + + double runtime = timer.elapsed_millis() / double(options.iterations); + double gflops = 2 * double(problem_size.product()) / 1e6 / runtime; // Two flops per multiply-add + + std::cout << name << ":\n"; + std::cout << " Runtime: " << runtime << " ms\n"; + std::cout << " GFLOPs: " << gflops << "\n"; + }; + + benchmark("Fused", [&](){ run_gemm(gemm); }); + benchmark("Unfused default", [&](){ run_reference(gemm_ref); }); + benchmark("Unfused optimized", [&](){ run_reference(gemm_opt); }); + + return true; + } +}; + +} // namespace example + +int main(int argc, const char ** argv) { + + bool notSupported = false; + + // CUDA 12 minimum required + if (__CUDACC_VER_MAJOR__ < 12) { + std::cerr << "This example requires CUDA Toolkit version 12 or later.\n"; + notSupported = true; + } + + cudaDeviceProp props; + CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); + + if (props.major < 9) { + std::cerr << "This example requires a device with compute capability 90 or higher.\n"; + notSupported = true; + } + if (notSupported) { + return EXIT_SUCCESS; // Do not fail CI checks on unsupported systems + } + + example::Options options; + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << "\n"; + return EXIT_SUCCESS; + } + + if (!options.valid()) { + std::cerr << "Invalid arguments." 
<< "\n"; + return EXIT_FAILURE; + } + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + bool result = true; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + + switch (options.mode) { + using namespace example; + case 0: { + std::cout << "Gather A,C + scatter D on M mode:" << std::endl; + using Runner = ExampleRunner< + cutlass::half_t, cutlass::layout::RowMajor, IndexedGather, // A + cutlass::half_t, cutlass::layout::ColumnMajor, NoGather, // B + cutlass::half_t, cutlass::layout::RowMajor, IndexedGather, // C + cutlass::half_t, cutlass::layout::RowMajor, IndexedGather, // D + float, float>; + result &= Runner(options, hw_info).run(options); + break; + } + case 1: { + std::cout << "Gather B,C + scatter D on N mode:" << std::endl; + using Runner = ExampleRunner< + cutlass::half_t, cutlass::layout::RowMajor, NoGather, // A + cutlass::half_t, cutlass::layout::ColumnMajor, IndexedGather, // B + cutlass::half_t, cutlass::layout::ColumnMajor, IndexedGather, // C + cutlass::half_t, cutlass::layout::ColumnMajor, IndexedGather, // D + float, float>; + result &= Runner(options, hw_info).run(options); + break; + } + case 2: { + std::cout << "Gather A,B on K mode:" << std::endl; + using Runner = ExampleRunner< + cutlass::half_t, cutlass::layout::ColumnMajor, IndexedGather, // A + cutlass::half_t, cutlass::layout::RowMajor, IndexedGather, // B + cutlass::half_t, cutlass::layout::RowMajor, NoGather, // C + cutlass::half_t, cutlass::layout::RowMajor, NoGather, // D + float, float>; + result &= Runner(options, hw_info).run(options); + break; + } + } +#endif + + return result ? EXIT_SUCCESS : EXIT_FAILURE; +} diff --git a/examples/52_hopper_gather_scatter_fusion/CMakeLists.txt b/examples/52_hopper_gather_scatter_fusion/CMakeLists.txt new file mode 100644 index 0000000000..bf67537002 --- /dev/null +++ b/examples/52_hopper_gather_scatter_fusion/CMakeLists.txt @@ -0,0 +1,32 @@ +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_example_add_executable( + 52_hopper_gather_scatter_fusion + 52_hopper_gather_scatter_fusion.cu + ) diff --git a/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp b/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp new file mode 100644 index 0000000000..c71109aa79 --- /dev/null +++ b/examples/52_hopper_gather_scatter_fusion/gather_gemm.hpp @@ -0,0 +1,421 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" + +#include "cute/tensor.hpp" + +#include "gather_tensor.hpp" + +namespace cutlass { + ///Forward declaration + struct CudaHostAdapter; +} + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileScheduler_, + class GatherA_, + class GatherB_ +> +class GemmGather +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static_assert(ArchTag::kMinComputeCapability >= 90); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(cute::is_void_v or cute::is_same_v, + "Non-persistent warp-specialized kernel does not support specializing the tile scheduler."); + using TileSchedulerTag = TileScheduler_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + + using GatherA = GatherA_; + using GatherB = GatherB_; + + // Kernel level shared memory storage + struct SharedStorage { + union TensorStorage { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16, _2> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + } pipelines; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + using GmemTiledCopyA = typename CollectiveMainloop::GmemTiledCopyA; + using GmemTiledCopyB = typename CollectiveMainloop::GmemTiledCopyB; + 
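    // The widths of these two tiled copies, together with the TiledMma size, fix how the
    // thread block is partitioned into producer (load) and consumer (MMA) warp groups by the
    // constants that follow. For example (numbers purely illustrative): a 128-thread tiled
    // copy and a 256-thread TiledMma give 1 load warp group + 2 MMA warp groups = 3 warp
    // groups, i.e. a 384-thread block, which satisfies the 2-or-3 warp-group assertion below.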
static_assert(cute::size(GmemTiledCopyA{}) == cute::size(GmemTiledCopyB{}), "Number of threads in A/B tiled copies must be the same."); + + static constexpr uint32_t NumLoadWarpGroups = cute::size(GmemTiledCopyA{}) / NumThreadsPerWarpGroup; + static constexpr uint32_t NumMmaWarpGroups = CUTE_STATIC_V(cute::size(TiledMma{})) / NumThreadsPerWarpGroup; + static constexpr uint32_t NumWarpGroups = NumLoadWarpGroups + NumMmaWarpGroups; + static_assert(NumWarpGroups == 2 || NumWarpGroups == 3, "Number of warp groups must be 2 or 3 for good performance."); + + static constexpr uint32_t MaxThreadsPerBlock = NumWarpGroups * NumThreadsPerWarpGroup; + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + GatherA gather_A{}; + GatherB gather_B{}; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + GatherA gather_A{}; + GatherB gather_B{}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. + static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + (void) workspace; + auto problem_shape = args.problem_shape; + if constexpr (detail::Has_SwapAB_v) { + // swap M/N + get<0>(problem_shape) = get<1>(args.problem_shape); + get<1>(problem_shape) = get<0>(args.problem_shape); + } + return { + args.mode, + problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace), + args.gather_A, + args.gather_B + }; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + return implementable; + } + + static + size_t + get_workspace_size(Arguments const& args) { + return 0; + } + + static + cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + auto cluster_shape = Shape<_1,_1,_1>{}; + auto tile_shape = TileShape{}; + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + return TileScheduler::get_tiled_cta_shape_mnl( + problem_shape_MNKL, tile_shape, cluster_shape); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + using namespace cute; + using X = Underscore; + + // Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. + #if ! 
defined(__CUDA_ARCH_FEAT_SM90_ALL) + if constexpr(size<0>(typename TiledMma::AtomShape_MNK{}) == 64) { + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); + return; + } + #endif + + enum class WarpGroupRole { + Producer = 0, + Consumer = 1, + }; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + int thread_idx = int(threadIdx.x); + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + int warp_group_idx = canonical_warp_group_idx(); + CUTLASS_ASSERT(warp_group_idx < NumWarpGroups); + WarpGroupRole warp_group_role = warp_group_idx < NumLoadWarpGroups ? WarpGroupRole::Producer : WarpGroupRole::Consumer; + + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.producer_arv_count = NumLoadWarpGroups * NumThreadsPerWarpGroup; + mainloop_pipeline_params.consumer_arv_count = NumMmaWarpGroups * NumThreadsPerWarpGroup; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.producer_arv_count = NumLoadWarpGroups * NumThreadsPerWarpGroup; + epi_load_pipeline_params.consumer_arv_count = NumMmaWarpGroups * NumThreadsPerWarpGroup; + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // Initialize starting pipeline states for the collectives + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + // Preconditions + static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. 
If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + + // Separate out problem shape for convenience + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + auto M = get<0>(problem_shape_MNKL); + auto N = get<1>(problem_shape_MNKL); + auto K = get<2>(problem_shape_MNKL); + auto L = get<3>(problem_shape_MNKL); + + // Represent the full tensors + Tensor mA_mkl = make_gather_tensor(make_gmem_ptr(params.mainloop.ptr_A), make_shape(M,K,L), params.mainloop.dA, params.gather_A); //(m,k,l) + Tensor mB_nkl = make_gather_tensor(make_gmem_ptr(params.mainloop.ptr_B), make_shape(N,K,L), params.mainloop.dB, params.gather_B); //(n,k,l) + + // Get the appropriate blocks for this thread block -- potential for thread block locality + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + TiledMma tiled_mma; + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, blk_shape, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, blk_shape, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + // Compute m_coord, n_coord, and l_coord with their post-tiled shapes + auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl)); + auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nkl)); + auto l_coord = idx2crd(int(blockIdx.z), shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Slice with m_coord and n_coord + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Get pipeline iterators and increments from tensor shapes + auto k_tile_iter = cute::make_coord_iterator(shape<2>(gA)); + auto k_tile_count = size<2>(gA); + auto c_tile_count = CollectiveEpilogue::get_load_pipe_increment(blk_shape); + auto d_tile_count = CollectiveEpilogue::get_store_pipe_increment(blk_shape); + + // Wait for all threads in the thread block + __syncthreads(); + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue{params.epilogue, shared_storage.tensors.epilogue}; + + if (warp_group_role == WarpGroupRole::Producer) { + // Compute tile residues for predication + auto m_max_coord = M - size<0>(gA) * get<0>(blk_coord); // M - BLK_M * m_coord + auto n_max_coord = N - size<0>(gB) * get<1>(blk_coord); // N - BLK_N * n_coord + auto k_residue = K - size<1>(gA) * size<2>(gA); // K - BLK_K * k_coord_max + auto residue_mnk = make_tuple(m_max_coord, n_max_coord, k_residue); + + collective_mainloop.load( + mainloop_pipeline, + mainloop_pipe_producer_state, + gA, + gB, + k_tile_iter, k_tile_count, + residue_mnk, + thread_idx, + shared_storage.tensors.mainloop + ); + // Update starting mainloop pipeline state for the pipeline drain + mainloop_pipe_producer_state.advance(k_tile_count); + // Make sure mainloop consumer has been waited upon before issuing epilogue load + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + + if (collective_epilogue.is_producer_load_needed()) { + epi_load_pipe_producer_state = + collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + tiled_mma, + thread_idx, + 
shared_storage.tensors.epilogue + ); + collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } + } + else if (warp_group_role == WarpGroupRole::Consumer) { + Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + + collective_mainloop.mma( + mainloop_pipeline, + mainloop_pipe_consumer_state, + accumulators, + k_tile_count, + warp_group_thread_idx, + shared_storage.tensors.mainloop, + params.mainloop + ); + + // Make sure the math instructions are done and free buffers before entering the epilogue + collective_mainloop.mma_tail( + mainloop_pipeline, + mainloop_pipe_consumer_state, + k_tile_count + ); + + // Epilogue and write to gD + collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + accumulators, + tiled_mma, + warp_group_thread_idx, + shared_storage.tensors.epilogue + ); + } + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/examples/52_hopper_gather_scatter_fusion/gather_kernel.cuh b/examples/52_hopper_gather_scatter_fusion/gather_kernel.cuh new file mode 100644 index 0000000000..592bf57e39 --- /dev/null +++ b/examples/52_hopper_gather_scatter_fusion/gather_kernel.cuh @@ -0,0 +1,136 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/numeric/math.hpp" + +namespace example +{ + +// Naive grid-stride loop implementation of gather +template +__global__ void +gather_kernel(Element const * __restrict__ input, + Element * __restrict__ output, + Func func, + int num_elems_input, + int num_elems_output, + cutlass::FastDivmod stride_divmod) +{ + Element const * input_b = input + blockIdx.z * num_elems_input; + Element * output_b = output + blockIdx.z * num_elems_output; + int tidx = threadIdx.x + blockIdx.x * blockDim.x; + for (int k = tidx; k < num_elems_output; k += blockDim.x * gridDim.x) { + int i,j; + stride_divmod(j, i, k); + output_b[k] = input_b[i + func(j) * stride_divmod.divisor]; + } +} + +// Gather elements along strided dimension of the tensor according to given indices +template +void +gather(Element const * input, + Element * output, + Func func, + int batch_size, + int num_elems_input, + int num_elems_output, + int stride, + cutlass::KernelHardwareInfo const& hw_info) +{ + // Upcast to uint128_t data type + int factor = 128 / cutlass::sizeof_bits::value; + assert(stride % factor == 0); + int stride_upcast = stride/factor; + int num_elems_input_upcast = num_elems_input / factor; + int num_elems_output_upcast = num_elems_output / factor; + + cutlass::FastDivmod stride_divmod(stride_upcast); + dim3 blocks(hw_info.sm_count, 1, batch_size); + gather_kernel<<>>(reinterpret_cast(input), + reinterpret_cast(output), + func, + num_elems_input_upcast, + num_elems_output_upcast, + stride_divmod); +} + +// Naive grid-stride loop implementation of scatter +template +__global__ void +scatter_kernel(Element const * __restrict__ input, + Element * __restrict__ output, + Func func, + int num_elems_input, + int num_elems_output, + cutlass::FastDivmod stride_divmod) +{ + Element const * input_b = input + blockIdx.z * num_elems_input; + Element * output_b = output + blockIdx.z * num_elems_output; + int tidx = threadIdx.x + blockIdx.x * blockDim.x; + for (int k = tidx; k < num_elems_input; k += blockDim.x * gridDim.x) { + int i,j; + stride_divmod(j, i, k); + output_b[i + func(j) * stride_divmod.divisor] = input_b[k]; + } +} + +// Gather elements along strided dimension of the tensor according to given indices +template +void +scatter(Element const * input, + Element * output, + Func func, + int batch_size, + int num_elems_input, + int num_elems_output, + int stride, + cutlass::KernelHardwareInfo const& hw_info) +{ + // Upcast to uint128_t data type + int factor = 128 / cutlass::sizeof_bits::value; + assert(stride % factor == 0); + int stride_upcast = stride/factor; + int num_elems_input_upcast = num_elems_input / factor; + int num_elems_output_upcast = num_elems_output / factor; + + cutlass::FastDivmod stride_divmod(stride_upcast); + dim3 blocks(hw_info.sm_count, 1, batch_size); + scatter_kernel<<>>(reinterpret_cast(input), + reinterpret_cast(output), + func, + num_elems_input_upcast, + num_elems_output_upcast, + stride_divmod); +} + +} // namespace example diff --git a/examples/52_hopper_gather_scatter_fusion/scatter_epilogue.hpp b/examples/52_hopper_gather_scatter_fusion/scatter_epilogue.hpp new file mode 100644 index 0000000000..dc9c0df804 --- /dev/null +++ b/examples/52_hopper_gather_scatter_fusion/scatter_epilogue.hpp @@ -0,0 +1,222 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & 
AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" + +#include "cute/tensor.hpp" +#include "cute/numeric/numeric_types.hpp" + +#include "gather_tensor.hpp" + +namespace cutlass::epilogue::collective { + +/// Applies an element wise operation to all elements within the fragment +/// and scatter-writes them out to destination storage. +/// GatherC and ScatterD are types of user-defined functions that apply the +/// transoformation of the strided coordinate (e.g. through an index array). +template < + class StrideC_, + class StrideD_, + class ThreadEpilogueOp_, + class EpilogueSchedule_, + class GatherC_, + class ScatterD_ +> +class EpilogueGatherScatter { +public: + // + // Type Aliases + // + using EpilogueSchedule = EpilogueSchedule_; + + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using ElementD = typename ThreadEpilogueOp::ElementD; + using StrideD = StrideD_; + + // Every epilogue needs these two GmemTiledCopy{C,D} aliases. + // If you don't know what they should be, just use void. 
+ using GmemTiledCopyC = void; + using GmemTiledCopyD = void; + + using GatherC = GatherC_; + using ScatterD = ScatterD_; + + static const int kOutputAlignment = ThreadEpilogueOp::kCount; + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + + static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + struct SharedStorage { }; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread_params{}; + ElementC const* ptr_C = nullptr; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + GatherC gather_C{}; + ScatterD scatter_D{}; + }; + + // Device side epilogue params + using Params = Arguments; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& _, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; + } + + template + static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + CUTLASS_HOST_DEVICE + EpilogueGatherScatter(Params const& params_) : params(params_) { } + + template< + class ProblemShapeMNKL, + class BlockShapeMNK, + class BlockCoordMNKL, + class FrgEngine, class FrgLayout, + class TiledMma, + class ResidueMNK + > + CUTLASS_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, + cute::Tensor const& accumulators, + TiledMma tiled_mma, + ResidueMNK residue_mnk, + int thread_idx, + char* smem_buf) + { + using namespace cute; + using X = Underscore; + + static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "ThreadBlock tile shape must be static"); + static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); + + (void) smem_buf; + ThreadEpilogueOp epilogue_op{params.thread_params}; + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + auto stride_c = detail::get_epilogue_stride(params.dC); + auto stride_d = detail::get_epilogue_stride(params.dD); + + // Represent the full output tensor + Tensor mC_mnl = make_gather_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c, params.gather_C); // (m,n,l) + Tensor mD_mnl = make_gather_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d, params.scatter_D); // (m,n,l) + + Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + + // Partition source and destination tiles to match the accumulator partitioning + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) + Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) + + static_assert(is_static::value, "Accumulator layout must be static"); + 
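+
+    // The gather/scatter views gC/gD are partitioned with the same tiled-MMA partitioner as a plain
+    // C/D tensor would be, so the per-thread fragments must match the accumulator fragment
+    // element-for-element; the assertions below check exactly that.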
CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), + "Source and destination must have the same number of elements."); + CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), + "Accumulator count must have the same destination element count."); + + // Make an identity coordinate tensor for predicating our output MN tile + auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); + Tensor tCcD = thr_mma.partition_C(cD); + + // source is needed + if (epilogue_op.is_source_needed()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + tCgD(i) = epilogue_op(accumulators(i), tCgC(i)); + } + } + } + // source is not needed, avoid load + else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + tCgD(i) = epilogue_op(accumulators(i)); + } + } + } + } + +private: + Params params; +}; + +} // namespace cutlass::epilogue::collective + diff --git a/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu b/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu new file mode 100644 index 0000000000..d24c5f294a --- /dev/null +++ b/examples/53_hopper_gemm_permute/53_hopper_gemm_permute.cu @@ -0,0 +1,979 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Hopper GEMM+permute example. + + This example demonstrates the fusion of tensor permutation operations with a Hopper GEMM kernel. 
+  It is similar in spirit to example 39_gemm_permute, but uses the CUTLASS 3 CollectiveBuilder API to
+  construct kernels that make use of Hopper architecture features: Tensor Memory Accelerator (TMA)
+  units and warpgroup-level MMA instructions.
+
+  Background
+  ----------
+
+  While a GEMM kernel computes a product of two matrices (rank-2 tensors), the source data may
+  come from higher-rank tensors by combining some of its modes (dimensions) into the row and column
+  modes of the matrix. These tensors are often outputs from previous layers of a network, and the
+  data may sometimes need to be reordered in memory before a GEMM is computed. Similarly, the output
+  of a GEMM may need to be reordered before a subsequent operation can be executed.
+
+  Consider this sample PyTorch code:
+
+    # Forward pass
+    D = torch.mm(A, B).view(M/D1, D1, D2, N/D2).permute(0, 2, 1, 3)
+
+    # Backward pass
+    grad_A = torch.mm(grad_D.permute(0, 2, 1, 3).view(M, N), B)
+
+  Executing the reordering as a separate operation requires committing the intermediate tensor to
+  memory, which increases the latency and memory footprint of the model. By fusing the permutation
+  with either the reading of the A/B matrices or the writing of the D matrix, we can avoid the
+  unnecessary global memory traffic and kernel launch overhead.
+
+  Implementation
+  --------------
+
+  The approach relies on two things:
+  - The ability of CUTLASS 3 to naturally perform general tensor contractions (GETT) owing to the
+    flexibility of CuTe's hierarchical layouts (see example 51_hopper_gett for more details).
+  - The hardware capabilities of Hopper TMA units, which allow loading multidimensional tensors with
+    (almost) arbitrary strides that can be used to represent a permuted view of the data.
+
+  In this example we reuse the permutation classes of example 39_gemm_permute as operation tags.
+  For each tag, a specialization of struct PermuteTraits<> provides the necessary information about
+  the target tensor shape and ordering of modes. The main class, ExampleRunner, then figures out the
+  overall (hierarchical) shape of the GEMM operation and computes the shape and strides for each
+  tensor, taking into account the permutation applied. We highlight the importance of specifying
+  consistent multidimensional shapes for all tensors (even those that are not permuted), as well as
+  choosing hierarchical GEMM tile sizes that best fit those shapes (in cases where some tensor
+  dimensions are known at compile time).
+
+  In addition, this example implements a standalone permutation kernel that is used both to verify
+  correctness of the fused kernel and to benchmark the fused kernel against an unfused version that
+  writes the intermediate tensor to memory.
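+
+  To make the layout-based view of permutation concrete, here is a small illustrative CuTe sketch
+  (a hypothetical standalone snippet, not code used by this example; the shapes are arbitrary).
+  A row-major 4x8 matrix and its transposed view share storage and differ only in their strides:
+
+    #include <cute/tensor.hpp>
+    using namespace cute;
+
+    float buf[32];
+    auto A  = make_tensor(&buf[0], make_layout(make_shape(4, 8), make_stride(8, 1)));  // (4,8):(8,1)
+    auto At = make_tensor(&buf[0], make_layout(make_shape(8, 4), make_stride(1, 8)));  // (8,4):(1,8)
+    // At(i,j) aliases A(j,i) -- no data movement is needed to "permute" the matrix.
+
+  Higher-rank permutations follow the same idea: a matrix mode is first reshaped into a nested
+  shape (e.g. M -> (M/D1, D1)), and the strides of the sub-modes are reordered according to the
+  desired permutation. The PermuteTraits<> specializations used below encode this information,
+  and the TMA units can then read or write the data directly through such strided views.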
+*/ + +#include "cutlass/arch/arch.h" +#include "cutlass/arch/mma.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/permute.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" + +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" + +#include "helper.h" +#include "permute_kernel.cuh" +#include "permute_traits.hpp" + +namespace example +{ + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +struct Options { + + bool help; + + cutlass::gemm::BatchedGemmCoord problem_size; + + float alpha; + float beta; + + bool reference_check; + int iterations; + + bool verbose; + + Options(): + help(false), + problem_size({2048, 2048, 2048, 8}), + alpha(1.0), + beta(1.0), + reference_check(true), + iterations(20), + verbose(false) { } + + bool valid() const { + return problem_size.m() > 0 + && problem_size.n() > 0 + && problem_size.k() > 0 + && problem_size.batch() > 0 + && iterations > 0; + } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + } + + cmd.get_cmd_line_argument("m", problem_size.m()); + cmd.get_cmd_line_argument("n", problem_size.n()); + cmd.get_cmd_line_argument("k", problem_size.k()); + cmd.get_cmd_line_argument("batch_size", problem_size.batch()); + + cmd.get_cmd_line_argument("alpha", alpha); + cmd.get_cmd_line_argument("beta", beta); + + cmd.get_cmd_line_argument("check", reference_check, true); + cmd.get_cmd_line_argument("iterations", iterations); + + cmd.get_cmd_line_argument("verbose", verbose, false); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << + "53_hopper_gemm_permute example\n" + "\n" + " This example uses the CUTLASS Library to fuse permute() on input/output tensors with GEMM\n" + "\n" + "Options:\n" + " --help If specified, displays this usage statement.\n" + " --m= GEMM M dimension\n" + " --n= GEMM N dimension\n" + " --k= GEMM K dimension\n" + " --alpha= GEMM alpha parameter\n" + " --beta= GEMM beta parameter\n" + " --iterations= Number of profiling iterations to perform.\n" + " --check= Validate results against a reference (unfused) imlementation" + " --verbose= Enable verbose output" + "\n" + "Examples:\n" + "\n" + "$ ./examples/53_hopper_gemm_permute/53_hopper_gemm_permute --m=4096 --n=2048 --k=3072 --batch_size=8\n"; + + return out; + } +}; + +using namespace cute; + +// Check the shapes assigned to the same mode of different tensors, +// ensure all permuted shapes are the same and return that shape. +template +auto +select_mode_shape(Shapes const & ... 
shapes) { + auto permuted_shapes = filter_tuple(cute::make_tuple(shapes...), [](auto shape) { + if constexpr (cute::rank(shape) > 1) { + return cute::make_tuple(shape); + } + else { + return cute::make_tuple(); + } + }); + if constexpr (cute::rank(permuted_shapes) == 0) { + return get<0>(cute::make_tuple(shapes...)); + } + else { + auto ref_shape = get<0>(permuted_shapes); + for_each(permuted_shapes, [&](auto shape) { + // This static assert fails to compile on GCC 7.5 + // static_assert(is_same::value, "Inconsistent shapes for the same mode"); + // This runtime check can be skipped if all permutations are required to be static. + if (shape != ref_shape) + { + print("Inconsistent shapes for the same mode: "); + print(ref_shape); print(" and "); print(shape); print("\n"); + exit(EXIT_FAILURE); + } + }); + return ref_shape; + } +} + +template +auto +compute_default_stride(Shape const & shape, StrideOrig const & stride_orig) { + // Only supports column-major and row-major, batch stride always comes last + if constexpr (is_constant<1, decltype(get<0>(stride_orig))>::value) { + return compact_col_major(shape); + } + else + { + return compact_order(shape, Step<_1,_0,_2>{}); + } +} + +// Divide a static scalar TileSize into static modes of Shape until either: +// - a dynamic mode is encountered +// - we run out of size to divide +// - no longer divisible by next shape +// Examples: +// select_tile_shape(_128, (_8,_16)) -> (_8,_16) +// select_tile_shape(_128, (_8,_32)) -> (_8,_16) +// select_tile_shape(_128, (_8, _4)) -> (_8,_4,_4) +// select_tile_shape(_128, (_8, 4)) -> (_8,_16) +template +auto +select_tile_shape(TileSize size, Shape const& shape) +{ + static_assert(is_static::value, "Tile size must be static"); + if constexpr (cute::rank(Shape{}) == 0) { + return cute::make_tuple(size); + } + else { + if constexpr (is_static>::value) { + auto div = front(shape); + if constexpr (size > div and size % div == 0) { + return prepend(select_tile_shape(size / div, take<1,tuple_size_v>(shape)), div); + } + else { + return cute::make_tuple(size); + } + } + else { + return cute::make_tuple(size); + } + } +} + +template +class ExampleRunner +{ +private: + + // Define shapes for each operand and original GEMM problem as a whole. + + using MatrixShape = Shape; // [M,N,L]/[M,K,L]/[N,K,L] + using ProblemShape = Shape; // [M,N,K,L] + + // Determine the CuTe stride for each of the four operands. + + using StrideA = cutlass::gemm::TagToStrideA_t; + using StrideB = cutlass::gemm::TagToStrideB_t; + using StrideC = cutlass::gemm::TagToStrideC_t; + using StrideD = cutlass::gemm::TagToStrideC_t; + + // Flags to check which operands will be permuted. + + static constexpr bool DoPermuteA = not cutlass::layout::is_trivial_permute; + static constexpr bool DoPermuteB = not cutlass::layout::is_trivial_permute; + static constexpr bool DoPermuteC = not cutlass::layout::is_trivial_permute; + static constexpr bool DoPermuteD = not cutlass::layout::is_trivial_permute; + + // For input operands, we must use inverse of the permutation operation + // to read data that is stored in original (un-permuted) order. + + using PermuteAReal = typename cutlass::layout::InversePermute::type; + using PermuteBReal = typename cutlass::layout::InversePermute::type; + using PermuteCReal = typename cutlass::layout::InversePermute::type; + using PermuteDReal = PermuteD; + + // Get permutation layout for each operand. 
+ // A permutation layout is a rank-3 layout in the usual CuTe mode ordering, + // but each mode may have a nested shape corresponding to the reshaping of + // the matrix into a multidimensional tensor, and the strides are computed + // taking the desired permutation into account. + + template + using LayoutPermute = remove_cvref_t(make_layout(MatrixShape{}, Stride{})))>; + + using LayoutAP = LayoutPermute; + using LayoutBP = LayoutPermute; + using LayoutCP = LayoutPermute; + using LayoutDP = LayoutPermute; + + // Now we want to build the unified problem shape for permute-GEMM. + // To do this, we check the corresponding mode in each tensor that has it. + // If at least one tensor has a mode that has been reshaped (i.e. rank > 1), + // its shape will be used as the reference shape for that mode in all tensors. + // If multiple tensors have reshaped mode, we additionally check that their + // shapes for that mode match. Otherwise, we can't define a consistent GEMM shape. + + using ShapeM = decltype(select_mode_shape(shape<0>(LayoutAP{}), shape<0>(LayoutCP{}), shape<0>(LayoutDP{}))); + using ShapeN = decltype(select_mode_shape(shape<0>(LayoutBP{}), shape<1>(LayoutCP{}), shape<1>(LayoutDP{}))); + using ShapeK = decltype(select_mode_shape(shape<1>(LayoutAP{}), shape<1>(LayoutBP{}))); + using ShapeL = decltype(select_mode_shape(shape<2>(LayoutAP{}), shape<2>(LayoutBP{}), shape<2>(LayoutCP{}), shape<2>(LayoutDP{}))); + + using ProblemShapePermute = Shape; + + using ShapeAPermute = Shape; + using ShapeBPermute = Shape; + using ShapeCPermute = Shape; + using ShapeDPermute = Shape; + + // Next, we must define the strides for each tensor. + // If the tensor is permuted, we take the strides produced by the permutation function. + // Otherwise, we compute default strides induced by the new (multidimensional) shape of the tensor. + // + // This won't always work in general if multiple tensors are permuted: e.g. if PermuteA affects + // modes M and K, and PermuteB affects modes N and L, the single stride for mode L of tensor A + // computed by PermuteA will be non-congruent with it's shape that is changed by PermuteB. + // To handle this correctly, a more complicated logic is needed to reconstruct multi-mode strides. + // This is not addressed here, as it's not a common requirement to permute multiple tensors in one GEMM. + + using StrideAPermute = conditional_t, decltype(compute_default_stride(ShapeAPermute{}, StrideA{}))>; + using StrideBPermute = conditional_t, decltype(compute_default_stride(ShapeBPermute{}, StrideB{}))>; + using StrideCPermute = conditional_t, decltype(compute_default_stride(ShapeCPermute{}, StrideC{}))>; + using StrideDPermute = conditional_t, decltype(compute_default_stride(ShapeDPermute{}, StrideD{}))>; + + // We need to select optimal tile shape based on the tile size specified by the user. + // This is done by dividing the tile size in each mode by the mode shape as much + // as possible (i.e. until we run out of tile size or encounter a dynamic sub-shape). + + using TileMPermute = decltype(select_tile_shape(get<0>(TileShape{}), ShapeM{})); + using TileNPermute = decltype(select_tile_shape(get<1>(TileShape{}), ShapeN{})); + using TileKPermute = decltype(select_tile_shape(get<2>(TileShape{}), ShapeK{})); + + using TileShapePermute = Shape; + + // Now we are ready to define the GEMM kernel types for both fused permute and reference paths. 
+ + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementEpilogue, + ElementC, StrideC, 128 / cutlass::sizeof_bits::value, + ElementD, StrideD, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveEpiloguePermute = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShapePermute, ClusterShape, cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementEpilogue, + ElementC, StrideCPermute, 128 / cutlass::sizeof_bits::value, + ElementD, StrideDPermute, 128 / cutlass::sizeof_bits::value, + cutlass::epilogue::collective::EpilogueScheduleAuto + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, StrideA, 128 / cutlass::sizeof_bits::value, + ElementB, StrideB, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using CollectiveMainloopPermute = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, StrideAPermute, 128 / cutlass::sizeof_bits::value, + ElementB, StrideBPermute, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, + TileShapePermute, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpiloguePermute::SharedStorage))>, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using GemmKernelPermute = cutlass::gemm::kernel::GemmUniversal< + ProblemShapePermute, + CollectiveMainloopPermute, + CollectiveEpiloguePermute + >; + + using GemmReference = cutlass::gemm::device::GemmUniversalAdapter; + using GemmPermute = cutlass::gemm::device::GemmUniversalAdapter; + + // Data members + + cutlass::gemm::BatchedGemmCoord problem_size; + ProblemShape problem_shape; + cutlass::KernelHardwareInfo hw_info; + + ElementEpilogue alpha; + ElementEpilogue beta; + + MatrixShape shape_A; + MatrixShape shape_B; + MatrixShape shape_C; + MatrixShape shape_D; + + StrideA stride_A; + StrideB stride_B; + StrideC stride_C; + StrideD stride_D; + + LayoutAP layout_AP; + LayoutBP layout_BP; + LayoutCP layout_CP; + LayoutDP layout_DP; + + ShapeM shape_M; + ShapeN shape_N; + ShapeK shape_K; + ShapeL shape_L; + + ProblemShapePermute problem_shape_permute; + + ShapeAPermute shape_A_permute; + ShapeBPermute shape_B_permute; + ShapeCPermute shape_C_permute; + ShapeDPermute shape_D_permute; + + StrideAPermute stride_A_permute; + StrideBPermute stride_B_permute; + StrideCPermute stride_C_permute; + StrideDPermute stride_D_permute; + + cutlass::device_memory::allocation tensor_a; + cutlass::device_memory::allocation tensor_b; + cutlass::device_memory::allocation tensor_c; + cutlass::device_memory::allocation tensor_d; + + cutlass::device_memory::allocation tensor_a_permuted; + cutlass::device_memory::allocation tensor_b_permuted; + cutlass::device_memory::allocation 
tensor_c_permuted; + cutlass::device_memory::allocation tensor_d_unpermuted; + cutlass::device_memory::allocation tensor_d_reference; + + cutlass::gemm::GemmUniversalMode gemm_mode; + + GemmPermute gemm_permute; + typename GemmPermute::Arguments arguments_permute; + cutlass::device_memory::allocation workspace_permute; + + GemmReference gemm_reference; + typename GemmReference::Arguments arguments_reference; + cutlass::device_memory::allocation workspace_reference; + + public: + + ExampleRunner(Options const & options, cutlass::KernelHardwareInfo const & hw_info) + : problem_size(options.problem_size), + problem_shape(problem_size.m(), problem_size.n(), problem_size.k(), problem_size.batch()), + hw_info(hw_info), + alpha(options.alpha), + beta(options.beta), + shape_A(make_shape(problem_size.m(), problem_size.k(), problem_size.batch())), + shape_B(make_shape(problem_size.n(), problem_size.k(), problem_size.batch())), + shape_C(make_shape(problem_size.m(), problem_size.n(), problem_size.batch())), + shape_D(make_shape(problem_size.m(), problem_size.n(), problem_size.batch())), + stride_A(cutlass::make_cute_packed_stride(StrideA{}, shape_A)), + stride_B(cutlass::make_cute_packed_stride(StrideB{}, shape_B)), + stride_C(cutlass::make_cute_packed_stride(StrideC{}, shape_C)), + stride_D(cutlass::make_cute_packed_stride(StrideD{}, shape_D)), + layout_AP(make_permute_layout(make_layout(shape_A, stride_A))), + layout_BP(make_permute_layout(make_layout(shape_B, stride_B))), + layout_CP(make_permute_layout(make_layout(shape_C, stride_C))), + layout_DP(make_permute_layout(make_layout(shape_D, stride_D))), + shape_M(select_mode_shape(shape<0>(layout_AP), shape<0>(layout_CP), shape<0>(layout_DP))), + shape_N(select_mode_shape(shape<0>(layout_BP), shape<1>(layout_CP), shape<1>(layout_DP))), + shape_K(select_mode_shape(shape<1>(layout_AP), shape<1>(layout_BP))), + shape_L(select_mode_shape(shape<2>(layout_AP), shape<2>(layout_BP), shape<2>(layout_CP), shape<2>(layout_DP))), + problem_shape_permute(shape_M, shape_N, shape_K, shape_L), + shape_A_permute(make_shape(shape_M, shape_K, shape_L)), + shape_B_permute(make_shape(shape_N, shape_K, shape_L)), + shape_C_permute(make_shape(shape_M, shape_N, shape_L)), + shape_D_permute(make_shape(shape_M, shape_N, shape_L)), + stride_A_permute(conditional_return(layout_AP.stride(), compute_default_stride(shape_A_permute, stride_A))), + stride_B_permute(conditional_return(layout_BP.stride(), compute_default_stride(shape_B_permute, stride_B))), + stride_C_permute(conditional_return(layout_CP.stride(), compute_default_stride(shape_C_permute, stride_C))), + stride_D_permute(conditional_return(layout_DP.stride(), compute_default_stride(shape_D_permute, stride_D))), + tensor_a(problem_size.m() * problem_size.k() * problem_size.batch()), + tensor_b(problem_size.k() * problem_size.n() * problem_size.batch()), + tensor_c(problem_size.m() * problem_size.n() * problem_size.batch()), + tensor_d(problem_size.m() * problem_size.n() * problem_size.batch()), + tensor_a_permuted(problem_size.m() * problem_size.k() * problem_size.batch()), + tensor_b_permuted(problem_size.k() * problem_size.n() * problem_size.batch()), + tensor_c_permuted(problem_size.m() * problem_size.n() * problem_size.batch()), + tensor_d_unpermuted(problem_size.m() * problem_size.n() * problem_size.batch()), + tensor_d_reference(problem_size.m() * problem_size.n() * problem_size.batch()), + gemm_mode(problem_size.batch() > 1 ? 
cutlass::gemm::GemmUniversalMode::kBatched : cutlass::gemm::GemmUniversalMode::kGemm), + arguments_permute{ + gemm_mode, + problem_shape_permute, + { + tensor_a.get(), stride_A_permute, + tensor_b.get(), stride_B_permute, + }, + { + { alpha, beta }, + tensor_c.get(), stride_C_permute, + tensor_d.get(), stride_D_permute + }, + hw_info + }, + workspace_permute(GemmPermute::get_workspace_size(arguments_permute)), + arguments_reference{ + gemm_mode, + problem_shape, + { + DoPermuteA ? tensor_a_permuted.get() : tensor_a.get(), stride_A, + DoPermuteB ? tensor_b_permuted.get() : tensor_b.get(), stride_B + }, + { + { alpha, beta }, + DoPermuteC ? tensor_c_permuted.get() : tensor_c.get(), stride_C, + DoPermuteD ? tensor_d_unpermuted.get() : tensor_d_reference.get(), stride_D + }, + hw_info + }, + workspace_reference(GemmReference::get_workspace_size(arguments_reference)) + { + if (options.verbose) { + print("Original GEMM problem:\n"); + print(" Problem shape: "); print(problem_shape); print("\n"); + print(" Layout A: "); print(make_layout(shape_A, stride_A)); print("\n"); + print(" Layout B: "); print(make_layout(shape_B, stride_B)); print("\n"); + print(" Layout C: "); print(make_layout(shape_C, stride_C)); print("\n"); + print(" Layout D: "); print(make_layout(shape_D, stride_D)); print("\n"); + print(" Tile shape: "); print(TileShape{}); print("\n"); + print("With fused permutations:\n"); + print(" Problem shape: "); print(problem_shape_permute); print("\n"); + print(" Layout A: "); print(make_layout(shape_A_permute, stride_A_permute)); print("\n"); + print(" Layout B: "); print(make_layout(shape_B_permute, stride_B_permute)); print("\n"); + print(" Layout C: "); print(make_layout(shape_C_permute, stride_C_permute)); print("\n"); + print(" Layout D: "); print(make_layout(shape_D_permute, stride_D_permute)); print("\n"); + print(" Tile shape: "); print(TileShapePermute{}); print("\n"); + } + + cutlass::reference::device::BlockFillRandomUniform(tensor_a.get(), tensor_a.size(), 1, ElementA(7), ElementA(-8), 0); + cutlass::reference::device::BlockFillRandomUniform(tensor_b.get(), tensor_b.size(), 2, ElementB(7), ElementB(-8), 0); + cutlass::reference::device::BlockFillRandomUniform(tensor_c.get(), tensor_c.size(), 3, ElementC(7), ElementC(-8), 0); + cutlass::reference::device::BlockFillSequential(tensor_d.get(), tensor_d.size(), ElementD(0), ElementD(0)); + + auto const gemm_init = [](auto & gemm, auto const & arguments, auto & workspace) { + cutlass::Status status = gemm.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Requested GEMM kernel cannot be used for this problem.\n" + << "Check problem sizes and alignment requirements." 
<< std::endl; + exit(EXIT_FAILURE); + } + status = gemm.initialize(arguments, workspace.get()); + CUTLASS_CHECK(status); + }; + + gemm_init(gemm_permute, arguments_permute, workspace_permute ); + gemm_init(gemm_reference, arguments_reference, workspace_reference); + } + + void debug_output(std::ostream & os) + { + auto print_tensor = [](std::ostream &os, char const * name, auto const & data, auto shape, auto stride) + { + std::vector> h_data(data.size()); + data.copy_to_host(h_data.data()); + Tensor t = make_tensor(h_data.data(), shape, stride); + os << "\n" << name << ": " << std::setw(4) << t << std::endl; + }; + auto [M,N,K,L] = problem_shape; + print_tensor(os, "A", tensor_a, make_shape(M,K,L), stride_A); + print_tensor(os, "B", tensor_b, make_shape(N,K,L), stride_B); + print_tensor(os, "C", tensor_c, make_shape(M,N,L), stride_C); + print_tensor(os, "D", tensor_d, make_shape(M,N,L), stride_D); + print_tensor(os, "D reference", tensor_d_reference, make_shape(M,N,L), stride_D); + } + + template + static float + run_gemm(Gemm &gemm) + { + GpuTimer timer; + if constexpr (DoTime) timer.start(); + cutlass::Status status = gemm.run(); + CUTLASS_CHECK(status); + if constexpr (DoTime) timer.stop(); + if constexpr (DoTime) return timer.elapsed_millis(); + else return 0; + } + + template + static float + run_permute(cutlass::device_memory::allocation const & input, + cutlass::device_memory::allocation & output, + Layout const& layout, + cutlass::KernelHardwareInfo const & hw_info) + { + auto idx = find_if(layout.stride(), [](auto x){ return not is_constant<1, decltype(x)>{}; }); + auto stride = get(layout.stride()); + + GpuTimer timer; + if constexpr (DoTime) timer.start(); + permute::kBatched, Permute>(input.get(), + output.get(), + size(take<0,2>(layout)), + static_cast(stride), + shape<2>(layout), + hw_info); + if constexpr (DoTime) timer.stop(); + if constexpr (DoTime) return timer.elapsed_millis(); + else return 0; + }; + + template + auto run_reference(Gemm2 &gemm) + { + float permute_time = 0.f; + if constexpr (DoPermuteA) { + auto orig_layout = make_original_layout(make_layout(shape_A, stride_A)); + permute_time += run_permute(tensor_a, tensor_a_permuted, orig_layout, hw_info); + } + if constexpr (DoPermuteB) { + auto orig_layout = make_original_layout(make_layout(shape_B, stride_B)); + permute_time += run_permute(tensor_b, tensor_b_permuted, select<1,0,2>(orig_layout), hw_info); + } + if constexpr (DoPermuteC) { + auto orig_layout = make_original_layout(make_layout(shape_C, stride_C)); + permute_time += run_permute(tensor_c, tensor_c_permuted, orig_layout, hw_info); + } + + float gemm_time = run_gemm(gemm); + + if constexpr (DoPermuteD) { + auto orig_layout = make_layout(shape_D, stride_D); + permute_time += run_permute(tensor_d_unpermuted, tensor_d_reference, orig_layout, hw_info); + } + + return cute::make_tuple(gemm_time, permute_time); + } + + bool verify() + { + run_gemm(gemm_permute); + run_reference(gemm_reference); + return cutlass::reference::device::BlockCompareEqual(tensor_d.get(), tensor_d_reference.get(), tensor_d.size()); + } + + bool run(Options const &options) + { + if (options.reference_check) { + if (!verify()) { + std::cout << "Failed validation" << std::endl; +#if 1 + debug_output(std::cout); +#endif + return false; + } + else { + std::cout << "Passed validation" << std::endl; + } + } + + // + // Run profiling loop + // + + auto const benchmark = [&](auto name, auto func) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { 
+ func(); + } + timer.stop(); + + double runtime = timer.elapsed_millis() / double(options.iterations); + double gflops = 2 * double(problem_size.product()) / 1e6 / runtime; // Two flops per multiply-add + + std::cout << name << ":\n"; + std::cout << " Runtime: " << runtime << " ms\n"; + std::cout << " GFLOPs: " << gflops << "\n"; + }; + + benchmark("Fused GEMM+permute", [&](){ run_gemm(gemm_permute); }); + benchmark("Unfused GEMM+permute", [&](){ run_reference(gemm_reference); }); + benchmark("Standalone GEMM only", [&](){ run_gemm(gemm_reference); }); + std::cout << "\n"; + + return true; + } +}; +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +} // namespace example + + +int main(int argc, char const **argv) +{ + bool notSupported = false; + + // CUDA 12 minimum required + if (__CUDACC_VER_MAJOR__ < 12) { + std::cerr << "This example requires CUDA Toolkit version 12 or later.\n"; + notSupported = true; + } + + cudaDeviceProp props; + CUDA_CHECK(cudaGetDeviceProperties(&props, 0)); + + if (props.major < 9) { + std::cerr << "This example requires a device with compute capability 90 or higher.\n"; + notSupported = true; + } + if (notSupported) { + return EXIT_SUCCESS; // Do not fail CI checks on unsupported systems + } +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + example::Options options; + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << "\n"; + return EXIT_SUCCESS; + } + + if (!options.valid()) { + std::cerr << "Invalid arguments." << "\n"; + return EXIT_FAILURE; + } + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + using namespace cute; + + // Define the data types + using ElementA = cutlass::half_t; + using ElementB = cutlass::half_t; + using ElementC = cutlass::half_t; + using ElementD = cutlass::half_t; + using ElementAccumulator = float; + using ElementEpilogue = float; + + // M=64 for TMA epilogue + using TileShape = Shape<_128,_128,_64>; + + // Cluster launch with TMA multicast for better perf + using ClusterShape = Shape<_2,_2,_1>; + + bool result = true; + +#define COMPILE_ALL_EXAMPLES 0 + + // REGULAR GEMMS + + { + print("===================================================\n"); + print("Tensor A: RowMajor, Tensor4DPermute0213<8,16>\n"); + using Runner = example::ExampleRunner, + ElementB, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementC, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementD, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementAccumulator, ElementEpilogue, + TileShape, ClusterShape>; + Runner runner(options, hw_info); + result &= runner.run(options); + } +#if COMPILE_ALL_EXAMPLES + { + print("===================================================\n"); + print("Tensor A: ColumnMajor, Tensor4DPermute0213<8,16>\n"); + using Runner = example::ExampleRunner, + ElementB, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementC, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementD, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementAccumulator, ElementEpilogue, + TileShape, ClusterShape>; + Runner runner(options, hw_info); + result &= runner.run(options); + } + { + print("===================================================\n"); + print("Tensor B: RowMajor, Tensor4DPermute0213<8,16>\n"); + using Runner = example::ExampleRunner, + ElementC, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementD, 
cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementAccumulator, ElementEpilogue, + TileShape, ClusterShape>; + Runner runner(options, hw_info); + result &= runner.run(options); + } +#endif + { + print("===================================================\n"); + print("Tensor B: ColumnMajor, Tensor4DPermute0213<8,16>\n"); + using Runner = example::ExampleRunner, + ElementC, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementD, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementAccumulator, ElementEpilogue, + TileShape, ClusterShape>; + Runner runner(options, hw_info); + result &= runner.run(options); + } + { + print("===================================================\n"); + print("Tensor D: RowMajor, Tensor4DPermute0213<8,16>\n"); + using Runner = example::ExampleRunner, + ElementAccumulator, ElementEpilogue, + TileShape, ClusterShape>; + Runner runner(options, hw_info); + result &= runner.run(options); + } +#if COMPILE_ALL_EXAMPLES + { + print("===================================================\n"); + print("Tensor D: ColumnMajor, Tensor4DPermute0213<8,16>\n"); + using Runner = example::ExampleRunner, + ElementAccumulator, ElementEpilogue, + TileShape, ClusterShape>; + Runner runner(options, hw_info); + result &= runner.run(options); + } +#endif + { + print("===================================================\n"); + print("Tensor A: RowMajor, Tensor5DPermute20314<16,8,4>\n"); + using Runner = example::ExampleRunner, + ElementB, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementC, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementD, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementAccumulator, ElementEpilogue, + TileShape, ClusterShape>; + Runner runner(options, hw_info); + result &= runner.run(options); + } +#if COMPILE_ALL_EXAMPLES + { + print("===================================================\n"); + print("Tensor A: ColumnMajor, Tensor5DPermute02413<16,8,4>\n"); + using Runner = example::ExampleRunner, + ElementB, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementC, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementD, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementAccumulator, ElementEpilogue, + TileShape, ClusterShape>; + Runner runner(options, hw_info); + result &= runner.run(options); + } +#endif + { + print("===================================================\n"); + print("Tensor D: RowMajor, Tensor5DPermute20314<16,8,4>\n"); + using Runner = example::ExampleRunner, + ElementAccumulator, ElementEpilogue, + TileShape, ClusterShape>; + Runner runner(options, hw_info); + result &= runner.run(options); + } +#if COMPILE_ALL_EXAMPLES + { + print("===================================================\n"); + print("Tensor D: ColumnMajor, Tensor5DPermute02413<16,8,4>\n"); + using Runner = example::ExampleRunner, + ElementAccumulator, ElementEpilogue, + TileShape, ClusterShape>; + Runner runner(options, hw_info); + result &= runner.run(options); + } +#endif + + // BATCHED GEMMS + + { + print("===================================================\n"); + print("Tensor A: RowMajor, Tensor4DPermuteBMM0213<8>\n"); + using Runner = example::ExampleRunner, + ElementB, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementC, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementD, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementAccumulator, ElementEpilogue, + TileShape, ClusterShape>; + Runner runner(options, hw_info); + result &= 
runner.run(options); + } + { + print("===================================================\n"); + print("Tensor D: RowMajor, Tensor4DPermuteBMM0213<8>\n"); + using Runner = example::ExampleRunner, + ElementAccumulator, ElementEpilogue, + TileShape, ClusterShape>; + Runner runner(options, hw_info); + result &= runner.run(options); + } +#if COMPILE_ALL_EXAMPLES + { + print("===================================================\n"); + print("Tensor A: ColumnMajor, Tensor4DPermuteBMM0321<8>\n"); + using Runner = example::ExampleRunner, + ElementB, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementC, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementD, cutlass::layout::RowMajor, cutlass::layout::NoPermute, + ElementAccumulator, ElementEpilogue, + TileShape, ClusterShape>; + Runner runner(options, hw_info); + result &= runner.run(options); + } + { + print("===================================================\n"); + print("Tensor D: RowMajor, Tensor4DPermuteBMM0321<8>\n"); + using Runner = example::ExampleRunner, + ElementAccumulator, ElementEpilogue, + TileShape, ClusterShape>; + Runner runner(options, hw_info); + result &= runner.run(options); + } +#endif + return result ? EXIT_SUCCESS : EXIT_FAILURE; +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) +} diff --git a/examples/53_hopper_gemm_permute/CMakeLists.txt b/examples/53_hopper_gemm_permute/CMakeLists.txt new file mode 100644 index 0000000000..dc70d95f6a --- /dev/null +++ b/examples/53_hopper_gemm_permute/CMakeLists.txt @@ -0,0 +1,36 @@ + +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ + + +cutlass_example_add_executable( + 53_hopper_gemm_permute + 53_hopper_gemm_permute.cu + ) + diff --git a/examples/53_hopper_gemm_permute/permute_kernel.cuh b/examples/53_hopper_gemm_permute/permute_kernel.cuh new file mode 100644 index 0000000000..8abe70108d --- /dev/null +++ b/examples/53_hopper_gemm_permute/permute_kernel.cuh @@ -0,0 +1,92 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Simple permutation kernel implementation. +*/ + +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/tensor_view.h" +#include "cutlass/fast_math.h" +#include "cute/numeric/numeric_types.hpp" + +namespace example +{ + +/** + * Assumes column-major input (M mode is contiguous, N mode is strided). + * For row major, the inputs must be switched accordingly. +*/ +template +__global__ void +permute_kernel(Element const* __restrict__ input, + Element* __restrict__ output, + Permute permute, + int64_t num_elems, + cutlass::FastDivmod stride_divmod) +{ + // CUTLASS 2.x batched permute functions assume 0 batch stride for target tensor + Element const * input_b = input + blockIdx.z * num_elems; + Element * output_b = output + (Batched ? 
0 : blockIdx.z * num_elems); + for (int64_t k = threadIdx.x + blockIdx.x * blockDim.x; k < num_elems; k += blockDim.x * gridDim.x) + { + int i, j; + stride_divmod(j, i, k); + output_b[permute(cutlass::PitchLinearCoord(i, j))] = input_b[i + j * stride_divmod.divisor]; + } +} + +template +void permute(Element const* input, + Element * output, + int64_t num_elems, + int stride, + int batch_count, + cutlass::KernelHardwareInfo const& hw_info) +{ + // Upcast to uint128_t data type + int factor = 128 / cutlass::sizeof_bits::value; + assert(stride % factor == 0); + int stride_upcast = stride/factor; + int64_t num_elems_upcast = num_elems / factor; + Permute permute_upcast(cutlass::PitchLinearCoord(stride_upcast, int(num_elems_upcast/stride_upcast)), stride_upcast); + + cutlass::FastDivmod stride_divmod(stride); + dim3 blocks(hw_info.sm_count, 1, batch_count); + permute_kernel<<>>(reinterpret_cast(input), + reinterpret_cast(output), + permute_upcast, + num_elems_upcast, + stride_upcast); +} + +} // namespace example diff --git a/examples/53_hopper_gemm_permute/permute_traits.hpp b/examples/53_hopper_gemm_permute/permute_traits.hpp new file mode 100644 index 0000000000..4c5baccac5 --- /dev/null +++ b/examples/53_hopper_gemm_permute/permute_traits.hpp @@ -0,0 +1,274 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Additional permutation information for the example. +*/ + +#include "cutlass/layout/permute.h" +#include "cutlass/gemm/gemm.h" + +namespace example +{ + +using namespace cute; + +// This struct is specialized below for different CUTLASS 2.x permutation ops +// to describe the operation in terms of target CuTe shape and stride order. 
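+
+// As a concrete (purely illustrative) reference: a "0213" permutation exchanges the two middle
+// modes of a rank-4 view, i.e. an element at coordinate (i0,i1,i2,i3) of a row-major
+// [s0,s1,s2,s3] tensor lands at coordinate (i0,i2,i1,i3) of a row-major [s0,s2,s1,s3] tensor.
+// The helper below is a hypothetical host-side sketch for intuition only; it is not used by the
+// traits in this file.
+inline void permute_0213_reference(float const* in, float* out, int s0, int s1, int s2, int s3) {
+  for (int i0 = 0; i0 < s0; ++i0)
+  for (int i1 = 0; i1 < s1; ++i1)
+  for (int i2 = 0; i2 < s2; ++i2)
+  for (int i3 = 0; i3 < s3; ++i3) {
+    int64_t src = ((int64_t(i0) * s1 + i1) * s2 + i2) * s3 + i3; // row-major [s0,s1,s2,s3]
+    int64_t dst = ((int64_t(i0) * s2 + i2) * s1 + i1) * s3 + i3; // row-major [s0,s2,s1,s3]
+    out[dst] = in[src];
+  }
+}
+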
+template +struct PermuteTraits {}; + +// Use X as a placeholder for shape division result +using X = Underscore; + +// Reshape a rank-2 shape into a multidimensional shape. +// Input: +// shape = (A, B, ...) +// target_shape = ((A1, ..., X, ..., Am), (B1, ..., X, ..., Bn), ...) +// Output: +// ((A1, ..., A/prod(A1..Am), ..., Am), (B1, ..., B/prod(B1..Bn), ..., Bn), ...) +template +constexpr auto +reshape(Shape const& shape, TargetShape const& target_shape) +{ + if constexpr (is_tuple::value) { + return cute::transform(shape, target_shape, [](auto && s, auto && t){ return reshape(s, t); }); + } + else { + auto idx = find_if(target_shape, [](auto x){ return is_underscore{}; }); + constexpr int I = decltype(idx)::value; + static_assert(I < tuple_size_v, "Each mode of TargetShape must contain a placeholder X"); + auto divisors = remove(target_shape); + assert(shape % product(divisors) == 0); + return replace(target_shape, shape / product(divisors)); + } +} + +// Given a tensor layout, compute a permutation layout consisting of: +// - sub-modes corresponding to the implied multidimensional shape of the source tensor +// - strides accounting for the permutation operation being performed +template +constexpr auto +make_permute_layout(Layout const& layout) { + static_assert(cute::rank(Shape{}) == 3, "Only rank-3 layouts are supported"); + if constexpr (Transpose) { + // Deal with tensor B by transposing appropriately before and after computing the permute layout. + // Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch]. + return select<1,0,2>(make_permute_layout(select<1,0,2>(layout))); + } + else { + if constexpr (cutlass::layout::is_trivial_permute) { + // Special case for NoPermute. Use a depth-2 layout for consistency with other permutations. + using ShapeProfile = tuple, tuple, tuple>; + return unflatten(layout, ShapeProfile{}); + } + else { + // Here's where the permutation layout is actually built + using ShapeProfile = typename PermuteTraits::ShapeProfile; + using StrideOrder = typename PermuteTraits::StrideOrder; + return make_ordered_layout(reshape(layout.shape(), ShapeProfile{}), StrideOrder{}); + } + } +} + +namespace detail +{ + +template +struct is_constant_pred { + template + constexpr auto operator()(T) { + return is_constant{}; + } +}; + +template +constexpr auto +inverse_impl(Permutation const & perm, seq) { + return cute::make_tuple(Int{})>{}...); +} + +} // namespace detail + +// Compute an inverse of a permutation represented as a tuple of cute::Int<> +template +constexpr auto +inverse(Permutation const & perm) { + auto flat_perm = flatten(perm); + return unflatten(detail::inverse_impl(flat_perm, tuple_seq{}), perm); +} + +template +using inverse_t = decltype(inverse(T{})); + +// Given a rank-2 layout of tensor that is assumed to have been permuted, +// compute the original rank-2 layout of the tensor prior to the permutation. +// This is needed to form the correct input to the standalone permutation kernel. +template +constexpr auto +make_original_layout(Layout const& layout) { + static_assert(cute::rank(Shape{}) == 3, "Only rank-3 layouts are supported"); + if constexpr (Transpose) { + // Deal with tensor B by transposing appropriately before and after computing the permute layout. + // Its CuTe-canonical mode order is [N,K,L], while permute operations expect [row,col,batch]. 
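+    // select<1,0,2> reorders the modes (N,K,L) -> (K,N,L) to match the [row,col,batch]
+    // order expected by the permute traits; applying select<1,0,2> again to the result
+    // restores the original (N,K,L) mode order.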
+ return select<1,0,2>(make_original_layout(select<1,0,2>(layout))); + } + else { + using ShapeProfile = typename PermuteTraits::ShapeProfile; + auto re_shape = flatten(reshape(layout.shape(), ShapeProfile{})); + using IndexOrder = typename PermuteTraits::IndexOrder; + auto orig_shape = transform_leaf(IndexOrder{}, [&](auto i){ return get(re_shape); }); + using OrigOrder = conditional_t(), seq<0,1,2>, seq<1,0,2>>; + // print("Permuted shape: "); print(reshape(layout.shape(), ShapeProfile{})); print("\n"); + // print("Original shape: "); print(orig_shape); print("\n"); + return make_ordered_layout(product_each(orig_shape), OrigOrder{}); + } +} + +/////////////// Tensor4DPermute0213 //////////////////// + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape>, Shape,X>, Shape>; + using IndexOrder = Step, Step<_1,_3>, Step<_4>>; + using StrideOrder = inverse_t; // Step, Step<_1,_3>, Step<_4>>; +}; + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape>, Shape,X>, Shape>; + using IndexOrder = Step, Step<_1,_3>, Step<_4>>; + using StrideOrder = inverse_t; // Step, Step<_1,_3>, Step<_4>>; +}; + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape,X>, Shape>, Shape>; + using IndexOrder = Step, Step<_0,_2>, Step<_4>>; + using StrideOrder = Step, Step<_0,_2>, Step<_4>>; +}; + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape,X>, Shape>, Shape>; + using IndexOrder = Step, Step<_0,_2>, Step<_4>>; + using StrideOrder = Step, Step<_0,_2>, Step<_4>>; +}; + +/////////////// Tensor4DPermuteBMM0321 //////////////////// + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = true; + using ShapeProfile = Shape, Shape, Shape,X>>; + using IndexOrder = Step, Step<_1>, Step<_3>>; + using StrideOrder = Step, Step<_2>, Step<_1,_3>>; +}; + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = true; + using ShapeProfile = Shape>, Shape, Shape>; + using IndexOrder = Step, Step<_2>, Step<_1,_3>>; + using StrideOrder = Step, Step<_1>, Step<_3>>; +}; + +/////////////// Tensor4DPermuteBMM0213 //////////////////// + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = true; + using ShapeProfile = Shape, Shape, Shape,X>>; + using IndexOrder = Step, Step<_1,_2>, Step<_3>>; + using StrideOrder = Step, Step<_0>, Step<_1,_3>>; +}; + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = true; + using ShapeProfile = Shape, Shape>, Shape>; + using IndexOrder = Step, Step<_1>, Step<_2,_3>>; + using StrideOrder = Step, Step<_0,_2>, Step<_3>>; +}; + +/////////////// Tensor5DPermute02413 //////////////////// + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape>, Shape,Int,X>, Shape>; + using IndexOrder = Step, Step<_4,_1,_3>, Step<_5>>; + using StrideOrder = inverse_t; // Step, Step<_1,_4,_2>, Step<_5>>; +}; + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape>, Shape,Int>, Shape>; + using IndexOrder = Step, Step<_1,_4,_2>, Step<_5>>; + using StrideOrder = inverse_t; // Step, Step<_4,_1,_3>, Step<_5>>; +}; + +/////////////// Tensor5DPermute20314 //////////////////// + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape,X>, Shape,Int>, Shape>; + using IndexOrder = Step, Step<_3,_1,_4>, 
Step<_5>>; + using StrideOrder = Step, Step<_0,_2,_4>, Step<_5>>; +}; + +template +struct PermuteTraits> +{ + static constexpr bool kBatched = false; + using ShapeProfile = Shape>, Shape,Int>, Shape>; + using IndexOrder = Step, Step<_2,_4,_1>, Step<_5>>; + using StrideOrder = Step, Step<_0,_3,_1>, Step<_5>>; +}; + +} // namespace example diff --git a/examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm.cu b/examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm.cu new file mode 100644 index 0000000000..726f6d222a --- /dev/null +++ b/examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm.cu @@ -0,0 +1,599 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Simple Hopper FP8 GEMM example using CUTLASS 3.0 APIs for NVIDIA Hopper architecture + + This example demonstrate a simple way to instantiate and run a FP8 GEMM using the new CUTLASS 3.0 + APIs on NVIDIA Hopper architecture. New features that will be showcased in this example are as follows: + + 1. NVIDIA Hopper architecture introduces a new series of tensor core instructions (GMMA) + which are more efficient than the Ampere tensor core instructions. + + 2. NVIDIA Hopper architecture includes new Tensor Memory Accelerator (TMA) unit to transfer large + blocks of data efficiently between global memory and shared memory. TMA also supports asynchronous + copies between thread blocks in a cluster. + + 3. This example uses the Warp Specialized kernel design (see /media/docs/efficient_gemm.md for details). + + 4. This example shows all important fusions used by FP8 gemm kernels, + i.e., scale factor for A, B, C, D tensor, the abs_max value of D tensor. + + 5. 
A simple way to tune the CTA rasterization direction and swizzle pattern of Hopper kernels. Both the + CTA rasterization direction and swizzle pattern impact cross-CTA locality of accesses. By tuning we can + improve performance. + + Examples: + + $ ./examples/54_hopper_fp8_warp_specialized_gemm/54_hopper_fp8_warp_specialized_gemm --m=2048 --n=2048 --k=2048 --rasterization=N --swizzle=2 +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler_params.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gett.hpp" + + +#include "helper.h" +#include "hopper_fp8_commandline.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::float_e5m2_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C matrix configuration +using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// D matrix configuration +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = AlignmentC; + +// Auxiliary matrix configuration and other fusion types +using ElementAux = ElementC; +using LayoutAux = LayoutC; +using ElementAmax = float; +using ElementBias = float; + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for epilogue computation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_64,_128,_128>; // Threadblock-level tile size +using ClusterShape = 
Shape<_1,_2,_1>; // Shape of the threadblocks in a cluster +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized; +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; +using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; +using FusionOperation = cutlass::epilogue::fusion::ScaledLinCombPerRowBiasEltActAmaxAux< + LayoutAux, cutlass::epilogue::thread::ReLU, ElementD, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementC>; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + TileShape, ClusterShape, + EpilogueTileType, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Extract information from Gemm kernel. +using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; +using ElementScalar = typename EpilogueOutputOp::ElementScalar; +using ElementAmax = typename EpilogueOutputOp::ElementAmax; +using ActivationFunctor = typename EpilogueOutputOp::ActivationFn; + +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; +using StrideAux = StrideD; + +constexpr bool IsDFp8 = + cute::is_same_v or + cute::is_same_v; + +constexpr bool IsAuxFp8 = + cute::is_same_v or + cute::is_same_v; + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; +StrideAux stride_aux; +uint64_t seed; + +cutlass::HostTensor tensor_A; +cutlass::HostTensor tensor_B; +cutlass::HostTensor tensor_C; +cutlass::HostTensor tensor_D; +cutlass::HostTensor tensor_ref_D; +cutlass::HostTensor tensor_aux; +cutlass::HostTensor tensor_ref_aux; + +using LayoutScalar = cutlass::layout::PackedVectorLayout; +cutlass::HostTensor scalar_alpha; +cutlass::HostTensor scalar_beta; +cutlass::HostTensor scale_A; +cutlass::HostTensor scale_B; +cutlass::HostTensor scale_C; +cutlass::HostTensor scale_D; +cutlass::HostTensor scale_aux; +cutlass::HostTensor abs_max_D; +cutlass::HostTensor reference_abs_max_D; +cutlass::HostTensor abs_max_aux; +cutlass::HostTensor reference_abs_max_aux; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +using RasterOrderOptions = typename cutlass::gemm::kernel::detail::PersistentTileSchedulerSm90Params::RasterOrderOptions; + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = 
cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_tensor( + cutlass::TensorView view, + uint64_t seed) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } + else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l)); + stride_aux = stride_D; + + auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); + auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); + auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); + + tensor_A.resize(a_coord); + tensor_B.resize(b_coord); + tensor_C.resize(c_coord); + tensor_D.resize(c_coord); + tensor_ref_D.resize(c_coord); + + initialize_tensor(tensor_A.host_view(), seed + 2022); + initialize_tensor(tensor_B.host_view(), seed + 2023); + initialize_tensor(tensor_C.host_view(), seed + 2024); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + + if (options.save_aux) { + tensor_aux.resize(c_coord); + tensor_aux.sync_device(); + tensor_ref_aux.resize(c_coord); + } + + if (options.device_scale) { + scalar_alpha.resize(cutlass::make_Coord(1)); + scalar_beta.resize(cutlass::make_Coord(1)); + scale_A.resize(cutlass::make_Coord(1)); + scale_B.resize(cutlass::make_Coord(1)); + scale_C.resize(cutlass::make_Coord(1)); + scale_D.resize(cutlass::make_Coord(1)); + scale_aux.resize(cutlass::make_Coord(1)); + + cutlass::reference::host::TensorFill(scalar_alpha.host_view(), options.alpha); + cutlass::reference::host::TensorFill(scalar_beta.host_view(), options.beta); + cutlass::reference::host::TensorFill(scale_A.host_view(), options.scale_a); + cutlass::reference::host::TensorFill(scale_B.host_view(), options.scale_b); + cutlass::reference::host::TensorFill(scale_C.host_view(), options.scale_c); + cutlass::reference::host::TensorFill(scale_D.host_view(), options.scale_d); + cutlass::reference::host::TensorFill(scale_aux.host_view(), options.scale_aux); + + scalar_alpha.sync_device(); + scalar_beta.sync_device(); + scale_A.sync_device(); + scale_B.sync_device(); + scale_C.sync_device(); + scale_D.sync_device(); + scale_aux.sync_device(); + } + + if (IsDFp8 && options.save_amax) { + abs_max_D.resize(cutlass::make_Coord(1)); + 
abs_max_D.sync_device(); + reference_abs_max_D.resize(cutlass::make_Coord(1)); + } + + if (IsAuxFp8 && options.save_aux && options.save_amax) { + abs_max_aux.resize(cutlass::make_Coord(1)); + abs_max_aux.sync_device(); + reference_abs_max_aux.resize(cutlass::make_Coord(1)); + } +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, options.l}, + {tensor_A.device_data(), stride_A, tensor_B.device_data(), stride_B}, + { + {}, // epilogue.thread + tensor_C.device_data(), stride_C, + tensor_D.device_data(), stride_D + } + }; + + auto &fusion_args = arguments.epilogue.thread; + fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + fusion_args.alpha_ptr = scalar_alpha.device_data(); + fusion_args.beta_ptr = scalar_beta.device_data(); + fusion_args.scale_a = options.scale_a; + fusion_args.scale_b = options.scale_b; + fusion_args.scale_c = options.scale_c; + fusion_args.scale_a_ptr = scale_A.device_data(); + fusion_args.scale_b_ptr = scale_B.device_data(); + fusion_args.scale_c_ptr = scale_C.device_data(); + + // ignored if tensor types are not fp8 + fusion_args.scale_d = options.scale_d; + fusion_args.scale_aux = options.scale_aux; + fusion_args.scale_d_ptr = scale_D.device_data(); + fusion_args.scale_aux_ptr = scale_aux.device_data(); + + // leaving/setting these as nullptr disables the fusion at runtime + fusion_args.bias_ptr = nullptr; + + if (options.save_aux) { + fusion_args.aux_ptr = tensor_aux.device_data(); + fusion_args.dAux = stride_aux; + if (options.save_amax) { + fusion_args.amax_aux_ptr = abs_max_aux.device_data(); + } + } + + if (options.save_amax) { + fusion_args.amax_D_ptr = abs_max_D.device_data(); + } + + arguments.scheduler.raster_order = options.raster; + // The tile scheduler will swizzle up to 8 and with the nearest multiple of 2 (i.e., 1, 2, 4, and 8) + arguments.scheduler.max_swizzle_size = options.swizzle; + + return arguments; +} + +bool verify(const Options &options) { + // + // Compute reference output + // + + // Create instantiation for device reference gemm kernel + auto A = cute::make_tensor(tensor_A.host_data(), + cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A)); + auto B = cute::make_tensor(tensor_B.host_data(), + cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B)); + auto C = cute::make_tensor(tensor_C.host_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C)); + auto D = cute::make_tensor(tensor_ref_D.host_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D)); + auto Aux = cute::make_tensor(tensor_ref_aux.host_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_aux)); + using unused_t = decltype(D); + + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D), + unused_t, // bias + decltype(Aux), + unused_t, // valpha + unused_t, // vbeta + ActivationFunctor + > epilogue_params; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.Aux = Aux; + epilogue_params.alpha = options.alpha; + epilogue_params.beta = options.beta; + epilogue_params.scale_a = options.scale_a; + epilogue_params.scale_b = 
options.scale_b; + epilogue_params.scale_c = options.scale_c; + epilogue_params.scale_d = options.scale_d; + epilogue_params.scale_aux = options.scale_aux; + epilogue_params.abs_max_D = reference_abs_max_D.host_data(); + epilogue_params.abs_max_Aux = reference_abs_max_aux.host_data(); + + // get reference result + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + // compare_reference + tensor_D.sync_host(); + bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), tensor_D.host_view()); + + if (IsDFp8 && options.save_amax) { + abs_max_D.sync_host(); + passed &= abs_max_D.at(cutlass::make_Coord(0)) == reference_abs_max_D.at(cutlass::make_Coord(0)); + } + + if (options.save_aux) { + tensor_aux.sync_host(); + passed &= cutlass::reference::host::TensorEquals(tensor_ref_aux.host_view(), tensor_aux.host_view()); + if (IsAuxFp8 && options.save_amax) { + abs_max_aux.sync_host(); + passed &= abs_max_aux.at(cutlass::make_Coord(0)) == reference_abs_max_aux.at(cutlass::make_Coord(0)); + } + } + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. 
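+    // Average over the profiled iterations; GFLOPS counts 2*M*N*K flops (two per
+    // multiply-add) divided by the average runtime in seconds (see Options::gflops).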
+ float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::string raster = "Heuristic"; + + if (options.raster == RasterOrderOptions::AlongN) { + raster = "Along N"; + } + else if (options.raster == RasterOrderOptions::AlongM) { + raster = "Along M"; + } + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::cout << " Rasterization: " << raster << " with a maximum CTA swizzle of " << options.swizzle << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // and must have compute capability at least 90. + if (__CUDACC_VER_MAJOR__ < 12) { + std::cerr << "This example requires CUDA 12 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major < 9) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater).\n"; + return 0; + } + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + run(options); +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/54_hopper_fp8_warp_specialized_gemm/CMakeLists.txt b/examples/54_hopper_fp8_warp_specialized_gemm/CMakeLists.txt new file mode 100644 index 0000000000..209b2779fe --- /dev/null +++ b/examples/54_hopper_fp8_warp_specialized_gemm/CMakeLists.txt @@ -0,0 +1,32 @@ +# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_example_add_executable( + 54_hopper_fp8_warp_specialized_gemm + 54_hopper_fp8_warp_specialized_gemm.cu + ) diff --git a/examples/54_hopper_fp8_warp_specialized_gemm/hopper_fp8_commandline.hpp b/examples/54_hopper_fp8_warp_specialized_gemm/hopper_fp8_commandline.hpp new file mode 100644 index 0000000000..96d8794d8e --- /dev/null +++ b/examples/54_hopper_fp8_warp_specialized_gemm/hopper_fp8_commandline.hpp @@ -0,0 +1,129 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +// Command line options parsing +template +struct Options { + + bool help = false; + + float alpha = 1.f, beta = 0.f; + float scale_a = 1.f, scale_b = 1.f, scale_c = 1.f, scale_d = 1.f, scale_aux = 1.f; + bool device_scale = false; + bool save_aux = true; + bool save_amax = true; + int iterations = 1000; + int m = 1024, n = 512, k = 1024, l = 1; + RasterOrderOptions raster; + int swizzle; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("scale_a", scale_a, 1.f); + cmd.get_cmd_line_argument("scale_b", scale_b, 1.f); + cmd.get_cmd_line_argument("scale_c", scale_c, 1.f); + cmd.get_cmd_line_argument("scale_d", scale_d, 1.f); + cmd.get_cmd_line_argument("scale_aux", scale_aux, 1.f); + cmd.get_cmd_line_argument("device_scale", device_scale, false); + cmd.get_cmd_line_argument("save_aux", save_aux, true); + cmd.get_cmd_line_argument("save_amax", save_amax, true); + cmd.get_cmd_line_argument("iterations", iterations); + + char raster_char; + cmd.get_cmd_line_argument("raster", raster_char); + + if (raster_char == 'N' || raster_char == 'n') { + raster = RasterOrderOptions::AlongN; + } + else if (raster_char == 'M' || raster_char == 'm') { + raster = RasterOrderOptions::AlongM; + } + else if (raster_char == 'H' || raster_char == 'h') { + raster = RasterOrderOptions::Heuristic; + } + + cmd.get_cmd_line_argument("swizzle", swizzle, 1); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "54_fp8_hopper_warp_specialized_gemm\n\n" + << " Hopper FP8 GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the l extent (batch) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --scale_a= Scaling factor for A\n" + << " --scale_b= Scaling factor for B\n" + << " --scale_c= Scaling factor for C\n" + << " --scale_d= Scaling factor for D (ignored for non-fp8 D)\n" + << " --scale_aux= Scaling factor for the auxiliary tensor (ignored for non-fp8 aux)\n" + << " --device_scale= Copy scalars to device memory before kernel launch (default: false)\n" + << " --save_aux= Save the pre-activation as an auxiliary tensor (default: true)\n" + << " --save_amax= Save the pre-scaled max absolute value of any fp8 outputs (aux and/or D) (default: true)\n" + << " --raster= CTA Rasterization direction (N for along N, M for along M, and H for heuristic)\n\n" + << " --swizzle= CTA Rasterization swizzle\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "54_fp8_hopper_warp_specialized_gemm" << " --m=1024 --n=512 --k=1024 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; diff --git a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu new file mode 100644 index 0000000000..ab82b40cca --- /dev/null +++ b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm.cu @@ -0,0 +1,683 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Hopper GEMM example with different data types using CUTLASS 3.0 APIs for NVIDIA Hopper architecture + + This example shows how to perform INT4 x BF16 GEMM and scale up the INT4 weight during dequantization. + + The narrower type always passes through the register file. Therefore, in cases where the narrower type is operand B, the collective will implicitly swap + A and B in the main loop. However, as a result of this collective performing implicit swaps, it does not support TMA epilogues. Consequently, it is essential to consider this when constructing the epilogue, + as illustrated in this example. + + Note that in this example, we explicitly swap A and B in order to use TMA epilogues. We do this since TMA epilogues are more performant on problem sizes of interest. + + As an additional optimization, we can reorder the narrow data type tensor such that elements read into register file by the same thread are contiguous in global and shared memory. + This promotes vectorization of shared memory loads and removes additional instructions on the critical path. For example, when MMA is performed in 16-bit data type, each thread reads + 4 groups of 2 elements that are logically contiguous in the same row (refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n16-a for thread-value layout). + If the narrow type is INT4 and tensor is major in K dim, only 8 bits can be read at a time, leading to extra load instructions and suboptimal utilization of shared memory throughput. + If we reorder the data offline to place all 16 elements read by a thread contiguously in memory, a single 64-bit load is sufficient. This reordering is often feasible when the quantized + tensor is static (e.g. weight tensor of a NN layer at inference time). This example demonstrates how such a reordering can be performed and communicated to the kernel when the options.shuffle is set to true. + + Furthermore, the conversion from {INT4, UINT4} to {FP16, BF16} can benefit from pre-shuffling the weights in the order [0,2,4,6,1,3,5,7]. This allows multiple nibbles to be efficiently extracted and up-converted + in parallel. The reordering is enabled by defining the layout type `ValueShuffle`. Refer to the partial specializations of `NumericArrayShuffleConverter` in "include/cutlass/detail/collective/mixed_input_utils.hpp" + for more details. + + It is expected that the scale's K dimension be scale_k = ceil_div(problem_k, group_size). + + Scales are always expected to be MN major. This means the fastest changing dimension must be M if A is scaled or N if B is scaled. + + If A is being scaled, the scales must have shape [M, scale_k], while if B is scaled, it must have shape [N, scale_k]. + + The implementation only supports "group-wise" scales. However, we can make it work for per-column scales by setting the group's size + equal to the gemm problem K. 
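+
+    For example, with --k=8192 and --g=8192 (as in the second command below), scale_k =
+    ceil_div(8192, 8192) = 1, so the scales for B have shape [N, 1], i.e. one scale per
+    column of B.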
+ + Limitations: + 1) The INT4 weights have additional encoding requirements. + 2) The scales must be MN major. That means if A is scaled, it must be column major, but if B is scaled it must be row major. + 3) The scales must have the same layout and groupsize. + 4) The groupsize must be greater or equal to the tile shape k. + 5) Currently, TMA epilogues cannot be used when the narrow type is the B operand. This limitation arises because the implementation always swaps the + operands to ensure that the narrow type passes through the register file, and TMA epilogues do not currently support implicit swap + transpose operations. + We plan to address this limitation in the future. However, we address this in the example by explicitly swapping and transposing the operands. + + Optimizing suggestions: + 1) Use a small tile size, since the register pressure for this GEMM (and RS GEMM in general) is high (it uses a lot of register space). + + Examples: + + Runs the mixed input batched gemm (with batch size 2), converting B to the type of A (mode 0) + $ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm --m=2048 --n=2048 --k=2048 --l=2 --mode=0 + + Runs the mixed input gemm, and applies a scaling factor to B before mma (mode 1). Applies a vector of scales to the entire + matrix (group size is the same as the gemm k dimension). + $ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_bf16_gemm --m=4096 --n=5120 --k=8192 --g=8192 --mode=1 +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" + +#include "helper.h" +#include "mixed_dtype_utils.hpp" +#include "packed_scale.hpp" +#include "reorder_utils.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// +using MmaType = cutlass::bfloat16_t; +using QuantType = cutlass::int4b_t; +constexpr int TileShapeK = 128 * 8 / sizeof_bits::value; + +// A matrix configuration +using ElementA = MmaType; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = QuantType; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// This example manually swaps and transposes, so 
keep transpose of input layouts +using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose::type; +using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose::type; + +using StrideA = cutlass::detail::TagToStrideA_t; +using StrideB = cutlass::detail::TagToStrideB_t; + +// Define the CuTe layout for reoredered quantized tensor B +// LayoutAtomQuant places values that will be read by the same thread in contiguous locations in global memory. +// It specifies the reordering within a single warp's fragment +//using ValueShuffle = Layout<_1>; // no value reordering +using ValueShuffle = Layout, Stride<_4,_1>>; // order [0,2,4,6,1,3,5,7] +int constexpr NumShuffleAtoms = 1; +using MmaAtomShape = Layout>>; +using LayoutAtomQuant = decltype(compute_memory_reordering_atom()); +using LayoutB_Reordered = decltype(tile_to_shape(LayoutAtomQuant{}, Layout, StrideB>{})); + +using ElementScale = MmaType; +using ElementZero = ElementScale; +using LayoutScale = cutlass::layout::RowMajor; + +// C/D matrix configuration +using ElementC = cutlass::half_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// D matrix configuration +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for epilogue computation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_128,_128,cute::Int>; // Threadblock-level tile size +using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative; // Kernel to launch based on the default setting in the Collective Builder +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; +using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + EpilogueTileType, + ElementAccumulator, ElementAccumulator, + // Transpose layout of D here since we use explicit swap + transpose + // the void type for C tells the builder to allocate 0 smem for the C matrix. + // We can enable this if beta == 0 by changing ElementC to void below. + ElementC, typename cutlass::layout::LayoutTranspose::type, AlignmentC, + ElementD, typename cutlass::layout::LayoutTranspose::type, AlignmentD, + EpilogueSchedule // This is the only epi supporting the required swap + transpose. + >::CollectiveOp; + +// ============================================================ MIXED INPUT NO SCALES ============================================================================ +// The collective will infer that the narrow type should be upcasted to the wide type. 
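+// Note (illustrative): since D = A x B implies D^T = B^T x A^T, swapping the operands lets the
+// kernel compute D^T instead, which is why the epilogue above is built with transposed C/D layouts.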
+// We swap A and B operands to the builder here +using CollectiveMainloopConvertOnly = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementB, LayoutB_Transpose, AlignmentB, + ElementA, LayoutA_Transpose, AlignmentA, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule + >::CollectiveOp; + +using GemmKernelConvertOnly = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopConvertOnly, + CollectiveEpilogue +>; + +using GemmConvertOnly = cutlass::gemm::device::GemmUniversalAdapter; + +using CollectiveMainloopConvertOnlyShuffled = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementB, LayoutB_Reordered, AlignmentB, + ElementA, LayoutA_Transpose, AlignmentA, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule + >::CollectiveOp; + +using GemmKernelConvertOnlyShuffled = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopConvertOnlyShuffled, + CollectiveEpilogue +>; + +using GemmConvertOnlyShuffled = cutlass::gemm::device::GemmUniversalAdapter; + +// =========================================================== MIXED INPUT WITH SCALES =========================================================================== +// The Scale information must get paired with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the scale information. +using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple, LayoutB_Transpose, AlignmentB, + ElementA, LayoutA_Transpose, AlignmentA, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule + >::CollectiveOp; + +using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopScaleOnly, + CollectiveEpilogue +>; + +using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter; + +using CollectiveMainloopScaleOnlyShuffled = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple, LayoutB_Reordered, AlignmentB, + ElementA, LayoutA_Transpose, AlignmentA, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule + >::CollectiveOp; + +using GemmKernelScaleOnlyShuffled = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopScaleOnlyShuffled, + CollectiveEpilogue +>; + +using GemmScaleOnlyShuffled = cutlass::gemm::device::GemmUniversalAdapter; + +// =========================================================== MIXED INPUT WITH SCALES AND ZEROS ================================================================== +// We specify scale + zero elements to indicate that we require both. Scales and biases have the same format. 
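Concretely, the usage text further below describes mode 1 as A @ (scale * B) and mode 2 as A @ (scale * B + zero-point), with one scale (and zero) per column of B and per group of g consecutive values along K. The host sketch below shows the indexing this implies for an element B(k, n); it is for intuition only, since the kernel performs this conversion in registers inside the mainloop, and the helper name is illustrative:

```cpp
#include <vector>

// B_q and B_dq are K x N (row-major here just for the sketch); scale and zero are
// N x scale_k with scale_k = ceil_div(K, g). Mode 1 passes with_zero = false.
void dequantize_sketch(std::vector<float>& B_dq, std::vector<int> const& B_q,
                       std::vector<float> const& scale, std::vector<float> const& zero,
                       int K, int N, int g, bool with_zero) {
  int scale_k = (K + g - 1) / g;
  for (int k = 0; k < K; ++k) {
    int group = k / g;                      // all values in a K-group share one scale/zero
    for (int n = 0; n < N; ++n) {
      float s = scale[n * scale_k + group];
      float z = with_zero ? zero[n * scale_k + group] : 0.0f;
      B_dq[k * N + n] = s * static_cast<float>(B_q[k * N + n]) + z;  // scale * B (+ zero)
    }
  }
}

int main() {
  int K = 4, N = 1, g = 2;                  // two groups of size 2 along K
  std::vector<int>   B_q   = {1, -2, 3, -4};
  std::vector<float> scale = {0.5f, 2.0f};  // one scale per group for the single column
  std::vector<float> zero  = {0.0f, 1.0f};
  std::vector<float> B_dq(K * N);
  dequantize_sketch(B_dq, B_q, scale, zero, K, N, g, /*with_zero=*/true);
  // B_dq == {0.5, -1.0, 7.0, -7.0}
  return 0;
}
```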
+using CollectiveMainloopScaleWithZeroPoint = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple, LayoutB_Transpose, AlignmentB, + ElementA, LayoutA_Transpose, AlignmentA, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule + >::CollectiveOp; + +using GemmKernelScaleWithZeroPoint = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopScaleWithZeroPoint, + CollectiveEpilogue +>; + +using GemmScaleWithZeroPoint = cutlass::gemm::device::GemmUniversalAdapter; + +using CollectiveMainloopScaleWithZeroPointShuffled = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple, LayoutB_Reordered, AlignmentB, + ElementA, LayoutA_Transpose, AlignmentA, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule + >::CollectiveOp; + +using GemmKernelScaleWithZeroPointShuffled = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopScaleWithZeroPointShuffled, + CollectiveEpilogue +>; + +using GemmScaleWithZeroPointShuffled = cutlass::gemm::device::GemmUniversalAdapter; +// ================================================================================================================================================================= + +using StrideC = typename GemmKernelScaleOnly::StrideC; +using StrideD = typename GemmKernelScaleOnly::StrideD; + +using StrideC_ref = cutlass::detail::TagToStrideC_t; +using StrideD_ref = cutlass::detail::TagToStrideC_t; + +// +// Data members +// + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideC_ref stride_C_ref; +StrideD stride_D; +StrideD_ref stride_D_ref; +uint64_t seed; + +LayoutB_Reordered layout_B_reordered; + +using StrideS = typename CollectiveMainloopScaleOnly::StrideScale; +using StrideS_ref = cutlass::detail::TagToStrideB_t; +StrideS stride_S; +StrideS_ref stride_S_ref; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_B_dq; +cutlass::DeviceAllocation block_scale; +cutlass::DeviceAllocation block_zero; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options : MixedDtypeOptions{ + bool shuffle = true; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + cmd.get_cmd_line_argument("shuffle", shuffle); + + this->MixedDtypeOptions::parse(argc, args); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "55_hopper_int4_bf16_gemm\n\n" + << " Hopper Mixed Data Type GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= The number of independent gemm problems with mnk shape\n" + << " --g= The size of each group for the scales. To broadcast a vector of scales or zeros, set the group size to K.\n" + << " --mode= The mode to run the gemm. 0 does (A @ B), 1 means A @ (scale * B), 2 means A @ (scale * B + zero-point).\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n" + << " --warmup= Number of warmup iterations to perform.\n\n" + << " --shuffle= Enable the offline layout swizzling.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "55_hopper_int4_bf16_gemm" << " --m=1024 --n=512 --k=1024 -g=1024 --l=10 --alpha=2 --mode=2 --beta=0.707 \n\n"; + + return out; + } +}; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(Options const& options) { + + auto shape_B = cute::make_shape(options.n, options.k, options.l); + int const scale_k = (options.k + options.g - 1) / options.g; + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B); + // Reverse stride here due to swap and transpose + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.n, options.m, options.l)); + stride_C_ref = cutlass::make_cute_packed_stride(StrideC_ref{}, cute::make_shape(options.m, options.n, options.l)); + // Reverse stride here due to swap and transpose + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.n, options.m, options.l)); + stride_D_ref = cutlass::make_cute_packed_stride(StrideD_ref{}, cute::make_shape(options.m, options.n, options.l)); + + auto layout_B = make_layout(shape_B, stride_B); + + auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); + auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); + auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); + + block_A.reset(a_coord.product()); + block_B.reset(b_coord.product()); + block_B_dq.reset(b_coord.product()); + block_C.reset(c_coord.product()); + block_D.reset(c_coord.product()); + block_ref_D.reset(c_coord.product()); + + block_scale.reset(scale_k * options.l * options.n); + block_zero.reset(scale_k * options.l * options.n); + + initialize_tensor(block_A, seed + 2022); + initialize_quant_tensor(block_B, seed + 2021); + initialize_tensor(block_C, seed + 2020); + initialize_scale(block_scale, options); + initialize_zero(block_zero, options); + + auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l); + stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(options.n, scale_k, options.l)); + stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l)); + auto 
layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref); + + dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g); + + if (options.shuffle) { + // Repeat the reorder layout atom to tile the whole tensor shape + layout_B_reordered = tile_to_shape(LayoutAtomQuant{}, shape_B); + reorder_tensor(block_B.get(), layout_B, layout_B_reordered); + + print("Quantized tensor layout: "); + print(layout_B_reordered); + print("\n"); + } +} + +/// Populates a Gemm::Arguments structure from the given commandline options +/// Swap the A and B tensors, as well as problem shapes here. +template +typename Gemm::Arguments args_from_options(Options const& options) +{ + using Args = typename Gemm::Arguments; + auto&& dB = [&]() { + if constexpr (cute::is_same_v || + cute::is_same_v || + cute::is_same_v) { + // offline swizzling is enabled. + return layout_B_reordered; + } + else { + return stride_B; + } + }(); + if (options.mode == MixedDtypeGemmMode::ConvertOnly) { + return Args { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.n, options.m, options.k, options.l}, + {block_B.get(), dB, block_A.get(), stride_A}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} + }; + } + else if (options.mode == MixedDtypeGemmMode::ScaleOnly) { + return Args { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.n, options.m, options.k, options.l}, + {block_B.get(), dB, block_A.get(), stride_A, block_scale.get(), stride_S, options.g}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} + }; + } + else if (options.mode == MixedDtypeGemmMode::ScaleWithZeroPoint) { + return Args { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.n, options.m, options.k, options.l}, + {block_B.get(), dB, block_A.get(), stride_A, block_scale.get(), stride_S, options.g, block_zero.get()}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} + }; + } else { + std::cerr << "Invalid mode " << options.mode << ". Must be 0, 1 or 2." << std::endl; + exit(-1); + } +} + +bool verify(Options const& options) { + // + // Compute reference output + // + + using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + MmaType, LayoutA, AlignmentA, + MmaType, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAuto, + cutlass::gemm::collective::KernelScheduleAuto + >::CollectiveOp; + + using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; + + using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopRef, + CollectiveEpilogueRef + >; + + using GemmRef = cutlass::gemm::device::GemmUniversalAdapter; + + typename GemmRef::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, options.l}, + {block_A.get(), stride_A, block_B_dq.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C_ref, block_ref_D.get(), stride_D_ref} + }; + + // Run the gemm where the scaling is performed outside of the kernel. 
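Once this reference kernel (which consumes the pre-dequantized `block_B_dq`) has run, the two outputs are compared with a relative tolerance rather than exactly, since the fused and unfused paths accumulate in different orders. A host sketch of that style of check follows; the exact formula implemented by `cutlass::reference::device::BlockCompareRelativelyEqual` may differ in detail, and the function name below is illustrative:

```cpp
#include <algorithm>
#include <cassert>
#include <cmath>

// Sketch only: values that are both effectively zero count as equal, otherwise the
// difference must be small relative to the larger magnitude.
bool relatively_equal_sketch(float a, float b, float epsilon, float non_zero_floor) {
  float mag = std::max(std::fabs(a), std::fabs(b));
  if (mag < non_zero_floor) {
    return true;
  }
  return std::fabs(a - b) <= epsilon * mag;
}

int main() {
  assert( relatively_equal_sketch(100.0f, 100.5f, 1e-2f, 1e-4f));  // 0.5% off: accepted
  assert(!relatively_equal_sketch(100.0f, 105.0f, 1e-2f, 1e-4f));  // 5% off: rejected
  return 0;
}
```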
+ GemmRef gemm_ref; + size_t workspace_size = GemmRef::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + CUTLASS_CHECK(gemm_ref.can_implement(arguments)); + CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm_ref.run()); + + // compare_reference + ElementD const epsilon(1e-2f); + ElementD const non_zero_floor(1e-4f); + bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + MixedDtypeResult result; + result.passed = verify(options); + mixed_dtype_profiling(gemm, options, result); + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + if (!result.passed) { + exit(-1); + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // and must have compute capability at least 90. + if (__CUDACC_VER_MAJOR__ < 12) { + std::cerr << "This example requires CUDA 12 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major < 9) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater).\n"; + return 0; + } + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + if (options.mode == MixedDtypeGemmMode::ConvertOnly) { + std::cout << "Running in no scale mode." << std::endl; + if (options.shuffle) { + std::cout << "Offline shuffle enabled." << std::endl; + run(options); + } else { + std::cout << "Offline shuffle disabled." << std::endl; + run(options); + } + } + else if (options.mode == MixedDtypeGemmMode::ScaleOnly) { + if (options.g == options.k) { + std::cout << "Running in per-column scale mode." << std::endl; + } else { + std::cout << "Running in group scale mode." 
<< std::endl; + } + if (options.shuffle) { + std::cout << "Offline shuffle enabled." << std::endl; + run(options); + } else { + std::cout << "Offline shuffle disabled." << std::endl; + run(options); + } + } + else if (options.mode == MixedDtypeGemmMode::ScaleWithZeroPoint) { + if (options.g == options.k) { + std::cout << "Running in per-column scale and zero mode." << std::endl; + } else { + std::cout << "Running in group scale and zero mode." << std::endl; + } + if (options.shuffle) { + std::cout << "Offline shuffle enabled." << std::endl; + run(options); + } else { + std::cout << "Offline shuffle disabled." << std::endl; + run(options); + } + } +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu new file mode 100644 index 0000000000..40fa689489 --- /dev/null +++ b/examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm.cu @@ -0,0 +1,562 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Hopper GEMM example with different data types using CUTLASS 3.0 APIs for NVIDIA Hopper architecture + + This example shows how to perform INT4 x FP8 GEMM and scale up the INT4 weight during dequantization. It uses a look-up table to avoid the multiplications + between INT4 and FP8. To trigger this method, use cutlass::Array as the scale type in the collective's arguments. + + However, this algorithm requires changes to the encoding of INT4 weights and scale factors. These changes must happen before launching the GEMM. 
See the helper functions + `unify_quant_encoding`, `initialize_packed_scale` in the header `fp8_packed_scale.hpp` for details. + + In a nutshell, the positive values of INT4 weights need to be encoded in the same way as negative values except for the sign bit. For each scale factor, + 8 negative results (-8 x scale, -7 x scale, ... -1 x scale) are packed together, forming a cutlass::Array value. + + The narrower type always passes through the register file. Therefore, in cases where the narrower type is operand B, the collective will implicitly swap + A and B in the main loop. However, as a result of this collective performing implicit swaps, it does not support TMA epilogues. Consequently, it is essential to consider this when constructing the epilogue, + as illustrated in this example. + + Note that in this example, we explicitly swap A and B in order to use TMA epilogues. We do this since TMA epilogues are more performant on problem sizes of interest. + + As an additional optimization, we can reorder the narrow data type tensor such that elements read into register file by the same thread are contiguous in global and shared memory. + This promotes vectorization of shared memory loads and removes additional instructions on the critical path. For example, when MMA is performed in FP8 data type, each thread reads + 4 groups of 4 elements that are logically contiguous in the same row (refer to https://docs.nvidia.com/cuda/parallel-thread-execution/index.html#wgmma-64n32-a for thread-value layout). + If the narrow type is INT4 and tensor is major in K dim, only 16 bits can be read at a time, leading to extra load instructions and suboptimal utilization of shared memory throughput. + If we reorder the data offline to place all 16 elements read by a thread contiguously in memory, a single 64-bit load is sufficient. This reordering is often feasible when the quantized + tensor is static (e.g. weight tensor of a NN layer at inference time). This example demonstrates how such a reordering can be performed and communicated to the kernel when the options.shuffle is set to true. + + It is expected that the scale's K dimension be scale_k = ceil_div(problem_k, group_size). + + Scales are always expected to be MN major. This means the fastest changing dimension must be M if A is scaled or N if B is scaled. + + If A is being scaled, the scales must have shape [M, scale_k], while if B is scaled, it must have shape [N, scale_k]. + + The implementation only supports "group-wise" scales. However, we can make it work for per-column scales by setting the group's size + equal to the gemm problem K. + + Limitations: + 1) Only supports INT4 x { FP8, INT8, UINT8 }. The scales must be the same as mma Type. Scale with zero-point mode is not supported. + 2) The INT4 weights and scale factors have additional encoding requirements. + 3) The scales must be MN major. That means if A is scaled, it must be column major, but if B is scaled it must be row major. + 4) The scales must have the same layout and groupsize. + 5) The groupsize must be greater or equal to the tile shape k. + 6) Currently, TMA epilogues cannot be used when the narrow type is the B operand. This limitation arises because the implementation always swaps the + operands to ensure that the narrow type passes through the register file, and TMA epilogues do not currently support implicit swap + transpose operations. + We plan to address this limitation in the future. 
However, we address this in the example by explicitly swapping and transposing the operands. + + Optimizing suggestions: + 1) Use a small tile size, since the register pressure for this GEMM (and RS GEMM in general) is high (it uses a lot of register space). + + Examples: + + Runs the mixed input batched gemm (with batch size 2), converting B to the type of A (mode 0) + $ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm --m=2048 --n=2048 --k=2048 --l=2 --mode=0 + + Runs the mixed input gemm, and applies a scaling factor to B before mma (mode 1). Applies a vector of scales to the entire + matrix (group size is the same as the gemm k dimension). + $ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_int4_fp8_gemm --m=4096 --n=5120 --k=8192 --g=8192 --mode=1 +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" + +#include "helper.h" +#include "mixed_dtype_utils.hpp" +#include "packed_scale.hpp" +#include "reorder_utils.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// +using MmaType = cutlass::float_e4m3_t; +using QuantType = cutlass::int4b_t; +constexpr int TileShapeK = 128 * 8 / sizeof_bits::value; + +// A matrix configuration +using ElementA = MmaType; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = QuantType; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// This example manually swaps and transposes, so keep transpose of input layouts +using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose::type; +using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose::type; + +using StrideA = cutlass::detail::TagToStrideA_t; +using StrideB = cutlass::detail::TagToStrideB_t; + +// Define the CuTe layout for reoredered quantized tensor B +// LayoutAtomQuant places values that will be read by the same thread in contiguous locations in global memory. 
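The bf16 variant of this example spells the per-fragment value order out as `ValueShuffle` with the comment "order [0,2,4,6,1,3,5,7]". As a toy, CUTLASS-free illustration of what such a value reordering does, the sketch below gathers every group of 8 elements in that order; the real atom additionally tiles the pattern across threads and MMA atoms and is applied to the quantized tensor by `reorder_tensor`:

```cpp
#include <array>
#include <cstdio>
#include <vector>

// Gather every group of 8 values in the order 0,2,4,6,1,3,5,7 (illustrative only).
std::vector<int> shuffle_values(std::vector<int> const& in) {
  constexpr std::array<int, 8> order = {0, 2, 4, 6, 1, 3, 5, 7};
  std::vector<int> out(in.size());
  for (std::size_t base = 0; base + 8 <= in.size(); base += 8) {
    for (int i = 0; i < 8; ++i) {
      out[base + i] = in[base + order[i]];
    }
  }
  return out;
}

int main() {
  std::vector<int> v = {10, 11, 12, 13, 14, 15, 16, 17};
  for (int x : shuffle_values(v)) std::printf("%d ", x);  // prints: 10 12 14 16 11 13 15 17
  std::printf("\n");
  return 0;
}
```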
+// It specifies the reordering within a single warp's fragment +using LayoutAtomQuant = decltype(compute_memory_reordering_atom()); +using LayoutB_Reordered = decltype(tile_to_shape(LayoutAtomQuant{}, Layout, StrideB>{})); + +using ElementScale = MmaType; +using ElementZero = ElementScale; // only for verify +using LayoutScale = cutlass::layout::RowMajor; + +// C/D matrix configuration +using ElementC = cutlass::half_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// D matrix configuration +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for epilogue computation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_128,_128,cute::Int>; // Threadblock-level tile size +using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative; // Kernel to launch based on the default setting in the Collective Builder +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; +using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + EpilogueTileType, + ElementAccumulator, ElementAccumulator, + // Transpose layout of D here since we use explicit swap + transpose + // the void type for C tells the builder to allocate 0 smem for the C matrix. + // We can enable this if beta == 0 by changing ElementC to void below. + ElementC, typename cutlass::layout::LayoutTranspose::type, AlignmentC, + ElementD, typename cutlass::layout::LayoutTranspose::type, AlignmentD, + EpilogueSchedule // This is the only epi supporting the required swap + transpose. + >::CollectiveOp; + +// =========================================================== MIXED INPUT WITH SCALES =========================================================================== +// The Scale information must get paired with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the scale information. 
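For this INT4 x FP8 kernel the scale entry in that tuple is a packed `cutlass::Array` holding, per scale factor s, the eight precomputed products -8*s through -1*s, which is what enables the multiplication-free lookup described in the file comment above. The sketch below illustrates the idea on the host; the actual bit-level encoding produced by `unify_quant_encoding` and `initialize_packed_scale` lives in `packed_scale.hpp`, so the indexing here is illustrative only:

```cpp
#include <array>
#include <cassert>

// Per scale factor s, pack the eight products -8*s .. -1*s into a small table.
std::array<float, 8> make_packed_scale(float s) {
  std::array<float, 8> table{};
  for (int i = 0; i < 8; ++i) {
    table[i] = -(8 - i) * s;        // table[0] = -8*s, ..., table[7] = -1*s
  }
  return table;
}

// Dequantize an INT4 value in [-8, 7] with a table lookup instead of a multiply.
float dequantize_int4(int q, std::array<float, 8> const& table) {
  if (q == 0) {
    return 0.0f;
  }
  return (q < 0) ? table[8 + q]     // q = -8 .. -1  maps to table[0] .. table[7]
                 : -table[8 - q];   // q =  1 ..  7  negates the matching entry
}

int main() {
  auto table = make_packed_scale(0.5f);
  assert(dequantize_int4(-8, table) == -4.0f);
  assert(dequantize_int4( 3, table) ==  1.5f);
  return 0;
}
```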
+using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple>, LayoutB_Transpose, AlignmentB, + ElementA, LayoutA_Transpose, AlignmentA, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule + >::CollectiveOp; + +using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopScaleOnly, + CollectiveEpilogue +>; + +using CollectiveMainloopShuffled = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple>, LayoutB_Reordered, AlignmentB, + ElementA, LayoutA_Transpose, AlignmentA, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule + >::CollectiveOp; + +using GemmKernelShuffled = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopShuffled, + CollectiveEpilogue +>; + +using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter; +using GemmShuffled = cutlass::gemm::device::GemmUniversalAdapter; + +using StrideC = typename GemmKernelScaleOnly::StrideC; +using StrideD = typename GemmKernelScaleOnly::StrideD; + +using StrideC_ref = cutlass::detail::TagToStrideC_t; +using StrideD_ref = cutlass::detail::TagToStrideC_t; + +// +// Data members +// + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideC_ref stride_C_ref; +StrideD stride_D; +StrideD_ref stride_D_ref; +uint64_t seed; + +LayoutB_Reordered layout_B_reordered; + +using StrideS = typename CollectiveMainloopScaleOnly::StrideScale; +using StrideS_ref = cutlass::detail::TagToStrideB_t; +StrideS stride_S; +StrideS_ref stride_S_ref; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_B_modified; +cutlass::DeviceAllocation block_B_dq; +cutlass::DeviceAllocation block_scale; +cutlass::DeviceAllocation> block_scale_packed; +cutlass::DeviceAllocation block_zero; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options : MixedDtypeOptions { + bool shuffle = true; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + cmd.get_cmd_line_argument("shuffle", shuffle); + + this->MixedDtypeOptions::parse(argc, args); + + mode = 1; // override the mode value to always be scale only mode + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "55_hopper_int4_fp8_gemm\n\n" + << " Hopper Mixed Data Type GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= The number of independent gemm problems with mnk shape\n" + << " --g= The size of each group for the scales. 
To broadcast a vector of scales or zeros, set the group size to K.\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n" + << " --warmup= Number of warmup iterations to perform.\n\n" + << " --shuffle= Enable the offline layout swizzling.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "55_hopper_int4_fp8_gemm" << " --m=1024 --n=512 --k=1024 -g=1024 --l=10 --alpha=2 --beta=0.707 \n\n"; + + return out; + } +}; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(Options const& options) { + + auto shape_B = cute::make_shape(options.n, options.k, options.l); + int const scale_k = (options.k + options.g - 1) / options.g; + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_B); + // Reverse stride here due to swap and transpose + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.n, options.m, options.l)); + stride_C_ref = cutlass::make_cute_packed_stride(StrideC_ref{}, cute::make_shape(options.m, options.n, options.l)); + // Reverse stride here due to swap and transpose + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.n, options.m, options.l)); + stride_D_ref = cutlass::make_cute_packed_stride(StrideD_ref{}, cute::make_shape(options.m, options.n, options.l)); + + auto layout_B = make_layout(shape_B, stride_B); + + auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); + auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); + auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); + + block_A.reset(a_coord.product()); + block_B.reset(b_coord.product()); + block_B_modified.reset(b_coord.product()); + block_B_dq.reset(b_coord.product()); + block_C.reset(c_coord.product()); + block_D.reset(c_coord.product()); + block_ref_D.reset(c_coord.product()); + + block_scale.reset(scale_k * options.l * options.n); + block_scale_packed.reset(scale_k * options.l * options.n); + block_zero.reset(scale_k * options.l * options.n); + + initialize_tensor(block_A, seed + 2022); + initialize_quant_tensor(block_B, seed + 2021); + unify_quant_encoding(block_B, block_B_modified); + initialize_tensor(block_C, seed + 2020); + initialize_scale(block_scale, options); + initialize_packed_scale(block_scale, block_scale_packed); + initialize_zero(block_zero, options); + + auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l); + stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(options.n, scale_k, options.l)); + stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l)); + auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref); + + dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g); + + if (options.shuffle) { + // Repeat the reorder layout atom to tile the whole tensor shape + layout_B_reordered = tile_to_shape(LayoutAtomQuant{}, shape_B); + reorder_tensor(block_B_modified.get(), layout_B, layout_B_reordered); + + 
print("Quantized tensor layout: "); + print(layout_B_reordered); + print("\n"); + } +} + +/// Populates a Gemm::Arguments structure from the given commandline options +/// Swap the A and B tensors, as well as problem shapes here. +template +typename Gemm::Arguments args_from_options(Options const& options) +{ + using Args = typename Gemm::Arguments; + auto&& dB = [&]() { + if constexpr (cute::is_same_v) { // offline swizzling is enabled. + return layout_B_reordered; + } + else { + return stride_B; + } + }(); + return Args { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.n, options.m, options.k, options.l}, + {block_B_modified.get(), dB, block_A.get(), stride_A, block_scale_packed.get(), stride_S, options.g}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} + }; +} + +bool verify(Options const& options) { + // + // Compute reference output + // + + // In this example, we use the GPU default kernels as a reference (unfused scale). + // This avoids numerical differences due to different accumulation order. + + // Again, due to numerical differences, we must use fast acc here when the mma type is + // FP8 as the fused implementation only supports fast acc at the moment. + constexpr bool IsFP8Input = cute::is_same_v || cute::is_same_v; + using FP8Sched = cute::conditional_t(TileShape{}) == 64, cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum>; + using ScheduleRef = cute::conditional_t; + + using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + MmaType, LayoutA, AlignmentA, + MmaType, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAuto, + ScheduleRef + >::CollectiveOp; + + using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; + + using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopRef, + CollectiveEpilogueRef + >; + + using GemmRef = cutlass::gemm::device::GemmUniversalAdapter; + + typename GemmRef::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, options.l}, + {block_A.get(), stride_A, block_B_dq.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C_ref, block_ref_D.get(), stride_D_ref} + }; + + // Run the gemm where the scaling is performed outside of the kernel. 
+ GemmRef gemm_ref; + size_t workspace_size = GemmRef::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + CUTLASS_CHECK(gemm_ref.can_implement(arguments)); + CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm_ref.run()); + + // compare_reference + ElementD const epsilon(1e-2f); + ElementD const non_zero_floor(1e-4f); + bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + MixedDtypeResult result; + result.passed = verify(options); + mixed_dtype_profiling(gemm, options, result); + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + if (!result.passed) { + exit(-1); + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // and must have compute capability at least 90. + if (__CUDACC_VER_MAJOR__ < 12) { + std::cerr << "This example requires CUDA 12 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major < 9) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater).\n"; + return 0; + } + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + if (options.g == options.k) { + std::cout << "Running in per-column scale mode." << std::endl; + } else { + std::cout << "Running in group scale mode." << std::endl; + } + if (options.shuffle) { + std::cout << "Offline shuffle enabled." << std::endl; + run(options); + } else { + std::cout << "Offline shuffle disabled." 
<< std::endl; + run(options); + } +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu b/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu new file mode 100644 index 0000000000..b482d0d15f --- /dev/null +++ b/examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm.cu @@ -0,0 +1,535 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Hopper GEMM example with different data types using CUTLASS 3.0 APIs for NVIDIA Hopper architecture + + This example shows how to perform GEMM where the input tensors A and B have different element types. CUTLASS currently supports upcasting + from a narrower (fewer bits) to a wider (more bits) type and utilizing the tensor core instruction for the wider type. For instance, when doing + INT8 x FP16, CUTLASS will convert INT8 -> FP16 and do math using FP16 tensor cores. Similarly, for INT4 x INT8, it will upcast to INT8 and issue math + using INT8 tensor cores. + + The narrower type always passes through the register file. Therefore, in cases where the narrower type is operand B, the collective will implicitly swap + A and B in the main loop. However, implicit swaps do not support TMA epilogues. Consequently, it is essential to consider this when constructing the epilogue, + as illustrated in this example. + + Note that in this example, we explicitly swap A and B in order to use TMA epilogues. We do this since TMA epilogues are more performant on problem sizes of interest. + + It is expected that the scale's K dimension be scale_k = ceil_div(problem_k, group_size). 
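For example, with the sizes used by the commands further below, this works out as follows; the same ceil_div appears in `initialize()` as `(options.k + options.g - 1) / options.g` (the helper name here is illustrative):

```cpp
#include <cassert>

// scale_k = ceil_div(problem_k, group_size)
int scale_k_for(int k, int g) { return (k + g - 1) / g; }

int main() {
  assert(scale_k_for(8192, 128)  == 64);  // --g=128: 64 scale groups along K
  assert(scale_k_for(8192, 8192) == 1);   // --g=K: one scale per column (per-column mode)
  return 0;
}
```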
+ + Scales are always expected to be MN major. This means the fastest changing dimension must be M if A is scaled or N if B is scaled. + + If A is being scaled, the scales should have shape [M, scale_k], while if B is scaled, it must have shape [N, scale_k]. + + The implementation only supports "group-wise" scales. However, we can make it work for per-column scales by setting the groups size + equal to the gemm problem K. + + Limitations: + 1) The narrow type must always be in K-major format. + 2) The scales and zeros must be MN major. That means if A is scaled, it must be column major, but if B is scaled it must be row major. + 3) The scales and the zeros must have the same layout and groupsize. + 4) The groupsize must be greater or equal to tile shape k. + 5) When dealing with 8-bit x {4-bit, 2-bit}, both inputs must be in K-major format. + 6) Currently, TMA epilogues cannot be used when the narrow type is the B operand. This limitation arises because the implementation always swaps the + operands to ensure that the narrow type passes through the register file, and TMA epilogues do not currently support implicit swap + transpose operations. + We plan to address this limitation in the future. However, we address this in the example by explicitly swapping and transposing the operands. + + Optimizing suggestions: + 1) Use a small tile size, since the register pressure for this GEMM (and RS GEMM in general) is high (it uses a lot of register space). + 2) Try avoid using scale or zero mode cause the computations will be the bottleneck. + + Examples: + + Runs the mixed input batched gemm (with batch size 2), converting B to the type of A (mode 0) + $ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm --m=2048 --n=2048 --k=2048 --l=2 --mode=0 + + Runs the mixed input gemm, and applies a scaling factor to B before mma (mode 1). Applies a vector of scales to the entire + matrix (group size is the same as the gemm k dimension). + $ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm --m=4096 --n=5120 --k=8192 --g=8192 --mode=1 + + Runs the mixed input gemm, and applies a scaling factor and adds a zero-point to B before mma (mode 2). Uses a group size of 128. 
+ $ ./examples/55_hopper_mixed_dtype_gemm/55_hopper_mixed_dtype_gemm --m=2048 --n=5120 --k=8192 --g=128 --mode=2 +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" + +#include "helper.h" +#include "mixed_dtype_utils.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// +using MmaType = cutlass::half_t; +using QuantType = cutlass::float_e4m3_t; +constexpr int TileShapeK = 128 * 8 / sizeof_bits::value; + +// A matrix configuration +using ElementA = MmaType; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = QuantType; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// This example manually swaps and transposes, so keep transpose of input layouts +using LayoutA_Transpose = typename cutlass::layout::LayoutTranspose::type; +using LayoutB_Transpose = typename cutlass::layout::LayoutTranspose::type; + +using ElementZero = cutlass::half_t; +using ElementScale = cutlass::half_t; +using LayoutScale = cutlass::layout::RowMajor; + +// C/D matrix configuration +using ElementC = cutlass::half_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::RowMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// D matrix configuration +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for epilogue computation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_128,_128,cute::Int>; // Threadblock-level tile size +using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedCooperative; // 
Kernel to launch based on the default setting in the Collective Builder +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecializedCooperative; +using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + EpilogueTileType, + ElementAccumulator, ElementAccumulator, + // Transpose layout of D here since we use explicit swap + transpose + // the void type for C tells the builder to allocate 0 smem for the C matrix. + // We can enable this if beta == 0 by changing ElementC to void below. + ElementC, typename cutlass::layout::LayoutTranspose::type, AlignmentC, + ElementD, typename cutlass::layout::LayoutTranspose::type, AlignmentD, + EpilogueSchedule // This is the only epi supporting the required swap + transpose. + >::CollectiveOp; + +// ============================================================ MIXED INPUT NO SCALES ============================================================================ +// The collective will infer that the narrow type should be upcasted to the wide type. +// We swap A and B operands to the builder here +using CollectiveMainloopConvertOnly = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementB, LayoutB_Transpose, AlignmentB, + ElementA, LayoutA_Transpose, AlignmentA, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule + >::CollectiveOp; + +using GemmKernelConvertOnly = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopConvertOnly, + CollectiveEpilogue +>; + +using GemmConvertOnly = cutlass::gemm::device::GemmUniversalAdapter; + +// =========================================================== MIXED INPUT WITH SCALES =========================================================================== +// The Scale information must get paired with the operand that will be scaled. In this example, B is scaled so we make a tuple of B's information and the scale information. +using CollectiveMainloopScaleOnly = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple, LayoutB_Transpose, AlignmentB, + ElementA, LayoutA_Transpose, AlignmentA, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule + >::CollectiveOp; + +using GemmKernelScaleOnly = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopScaleOnly, + CollectiveEpilogue +>; + +using GemmScaleOnly = cutlass::gemm::device::GemmUniversalAdapter; + +// =========================================================== MIXED INPUT WITH SCALES AND ZEROS ================================================================== +// We specify scale + zero elements to indicate that we require both. Scales and biases have the same format. 
+using CollectiveMainloopScaleWithZeroPoint = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + cute::tuple, LayoutB_Transpose, AlignmentB, + ElementA, LayoutA_Transpose, AlignmentA, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule + >::CollectiveOp; + +using GemmKernelScaleWithZeroPoint = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopScaleWithZeroPoint, + CollectiveEpilogue +>; + +using GemmScaleWithZeroPoint = cutlass::gemm::device::GemmUniversalAdapter; +// ================================================================================================================================================================= + +using StrideA = cutlass::detail::TagToStrideA_t; +using StrideB = cutlass::detail::TagToStrideB_t; +using StrideC = typename GemmKernelScaleWithZeroPoint::StrideC; +using StrideD = typename GemmKernelScaleWithZeroPoint::StrideD; + +using StrideC_ref = cutlass::detail::TagToStrideC_t; +using StrideD_ref = cutlass::detail::TagToStrideC_t; + +// +// Data members +// + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideC_ref stride_C_ref; +StrideD stride_D; +StrideD_ref stride_D_ref; +uint64_t seed; + +// Scale and Zero share a stride since the layout and shapes must be the same. +using StrideS = typename CollectiveMainloopScaleWithZeroPoint::StrideScale; +using StrideS_ref = cutlass::detail::TagToStrideB_t; +StrideS stride_S; +StrideS_ref stride_S_ref; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_B_dq; +cutlass::DeviceAllocation block_scale; +cutlass::DeviceAllocation block_zero; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(MixedDtypeOptions const& options) { + + auto shape_b = cute::make_shape(options.n, options.k, options.l); + int const scale_k = (options.k + options.g - 1) / options.g; + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, shape_b); + // Reverse stride here due to swap and transpose + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.n, options.m, options.l)); + stride_C_ref = cutlass::make_cute_packed_stride(StrideC_ref{}, cute::make_shape(options.m, options.n, options.l)); + // Reverse stride here due to swap and transpose + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.n, options.m, options.l)); + stride_D_ref = cutlass::make_cute_packed_stride(StrideD_ref{}, cute::make_shape(options.m, options.n, options.l)); + + auto a_coord = cutlass::make_Coord(options.m * options.l, 
options.k); + auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); + auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); + + block_A.reset(a_coord.product()); + block_B.reset(b_coord.product()); + block_B_dq.reset(b_coord.product()); + block_C.reset(c_coord.product()); + block_D.reset(c_coord.product()); + block_ref_D.reset(c_coord.product()); + + block_scale.reset(scale_k * options.l * options.n); + block_zero.reset(scale_k * options.l * options.n); + + initialize_tensor(block_A, seed + 2022); + initialize_quant_tensor(block_B, seed + 2021); + initialize_tensor(block_C, seed + 2020); + initialize_scale(block_scale, options); + initialize_zero(block_zero, options); + + auto layout_B = make_layout(shape_b, stride_B); + + auto shape_scale_zero = cute::make_shape(options.n, scale_k, options.l); + stride_S = cutlass::make_cute_packed_stride(StrideS{}, cute::make_shape(options.n, scale_k, options.l)); + stride_S_ref = cutlass::make_cute_packed_stride(StrideS_ref{}, cute::make_shape(options.n, scale_k, options.l)); + auto layout_scale_zero = make_layout(shape_scale_zero, stride_S_ref); + + dequantize_weight(block_B_dq.get(), block_B.get(), layout_B, block_scale.get(), block_zero.get(), layout_scale_zero, options.g); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +template +Args args_from_options(MixedDtypeOptions const& options) +{ +// Swap the A and B tensors, as well as problem shapes here. + if (options.mode == MixedDtypeGemmMode::ConvertOnly) { + return Args { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.n, options.m, options.k, options.l}, + {block_B.get(), stride_B, block_A.get(), stride_A}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} + }; + } + else if (options.mode == MixedDtypeGemmMode::ScaleOnly) { + return Args { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.n, options.m, options.k, options.l}, + {block_B.get(), stride_B, block_A.get(), stride_A, block_scale.get(), stride_S, options.g}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} + }; + } + else if (options.mode == MixedDtypeGemmMode::ScaleWithZeroPoint) { + return Args { + cutlass::gemm::GemmUniversalMode::kGemm, + {options.n, options.m, options.k, options.l}, + {block_B.get(), stride_B, block_A.get(), stride_A, block_scale.get(), stride_S, options.g, block_zero.get()}, + {{options.alpha, options.beta}, block_C.get(), stride_C, block_D.get(), stride_D} + }; + } else { + std::cerr << "Invalid mode " << options.mode << ". Must be 0, 1 or 2." << std::endl; + exit(-1); + } +} + +bool verify(MixedDtypeOptions const& options) { + // + // Compute reference output + // + + // In this example, we use the GPU default kernels as a reference (unfused scale) + // This avoids numerical differences due to different accumulation order. + + // Again, due to numerical differences, we must use fast acc here when the mma type is + // FP8 as the fused implementation only supports fast acc at the moment. 
+ constexpr bool IsFP8Input = cute::is_same_v || cute::is_same_v; + using FP8Sched = cute::conditional_t(TileShape{}) == 64, cutlass::gemm::KernelTmaWarpSpecializedPingpongFP8FastAccum, cutlass::gemm::KernelTmaWarpSpecializedCooperativeFP8FastAccum>; + using ScheduleRef = cute::conditional_t; + + using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + MmaType, LayoutA, AlignmentA, + MmaType, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAuto, + ScheduleRef + >::CollectiveOp; + + using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + cutlass::epilogue::NoSmemWarpSpecialized + >::CollectiveOp; + + using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloopRef, + CollectiveEpilogueRef + >; + + using GemmRef = cutlass::gemm::device::GemmUniversalAdapter; + + typename GemmRef::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, options.l}, + {block_A.get(), stride_A, block_B_dq.get(), stride_B}, + {{options.alpha, options.beta}, block_C.get(), stride_C_ref, block_ref_D.get(), stride_D_ref} + }; + + // Run the gemm where the scaling is performed outside of the kernel. + GemmRef gemm_ref; + size_t workspace_size = GemmRef::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + CUTLASS_CHECK(gemm_ref.can_implement(arguments)); + CUTLASS_CHECK(gemm_ref.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm_ref.run()); + + // compare_reference + ElementD const epsilon(1e-2f); + ElementD const non_zero_floor(1e-4f); + bool passed = cutlass::reference::device::BlockCompareRelativelyEqual(block_ref_D.get(), block_D.get(), block_D.size(), epsilon, non_zero_floor); + return passed; +} + +/// Execute a given example GEMM computation +template +int run(MixedDtypeOptions &options) +{ + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + MixedDtypeResult result; + result.passed = verify(options); + mixed_dtype_profiling(gemm, options, result); + std::cout << " Disposition: " << (result.passed ? 
"Passed" : "Failed") << std::endl; + if (!result.passed) { + exit(-1); + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // and must have compute capability at least 90. + if (__CUDACC_VER_MAJOR__ < 12) { + std::cerr << "This example requires CUDA 12 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major < 9) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater).\n"; + return 0; + } + // + // Parse options + // + + MixedDtypeOptions options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + if (options.mode == MixedDtypeGemmMode::ConvertOnly) { + std::cout << "Running in no scale mode." << std::endl; + run(options); + } + else if (options.mode == MixedDtypeGemmMode::ScaleOnly) { + if (options.g == options.k) { + std::cout << "Running in per-column scale mode." << std::endl; + } else { + std::cout << "Running in group scale mode." << std::endl; + } + run(options); + } + else if (options.mode == MixedDtypeGemmMode::ScaleWithZeroPoint) { + if (options.g == options.k) { + std::cout << "Running in per-column scale and zero mode." << std::endl; + } else { + std::cout << "Running in group scale and zero mode." << std::endl; + } + run(options); + } +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/55_hopper_mixed_dtype_gemm/CMakeLists.txt b/examples/55_hopper_mixed_dtype_gemm/CMakeLists.txt new file mode 100644 index 0000000000..23dca4f3fd --- /dev/null +++ b/examples/55_hopper_mixed_dtype_gemm/CMakeLists.txt @@ -0,0 +1,81 @@ + +# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# Note that we set --iterations=0 for all tests below to disable the performance benchmarking. +# Only the correctness check will be run by these commands. + +set(TEST_DIRECT_BATCHED --m=2048 --n=2048 --k=2048 --l=2 --mode=0 --iterations=0) # Direct conversion + +set(TEST_SCALE_PERCOL --m=4096 --n=5120 --k=8192 --g=8192 --mode=1 --iterations=0) # Per Column scaling +set(TEST_SCALE_ZERO_PERCOL --m=4096 --n=5120 --k=8192 --g=8192 --mode=2 --iterations=0) # Per Column scaling + +set(TEST_SCALE_GROUP --m=2048 --n=5120 --k=8192 --g=512 --mode=1 --iterations=0) # Group-wise scaling +set(TEST_SCALE_ZERO_GROUPED --m=2048 --n=5120 --k=8192 --g=256 --mode=2 --iterations=0) # Group-wise scaling with zero-point + +set(TEST_SCALE_RESIDUE --m=128 --n=128 --k=320 --g=128 --mode=1 --iterations=0) # Final group has residue +set(TEST_SCALE_ZERO_RESIDUE --m=128 --n=128 --k=192 --g=128 --mode=2 --iterations=0) # Final group has residue + +set(TEST_ALPHA_BETA --alpha=0.5 --beta=0.7 --mode=2 --iterations=0) # Alpha and Beta with default shapes + + +cutlass_example_add_executable( + 55_hopper_mixed_dtype_gemm + 55_hopper_mixed_dtype_gemm.cu + TEST_COMMAND_OPTIONS + TEST_DIRECT_BATCHED + TEST_SCALE_PERCOL + TEST_SCALE_ZERO_PERCOL + TEST_SCALE_GROUP + TEST_SCALE_ZERO_GROUPED + TEST_SCALE_RESIDUE + TEST_SCALE_ZERO_RESIDUE + # TEST_ALPHA_BETA + ) + +cutlass_example_add_executable( + 55_hopper_int4_fp8_gemm + 55_hopper_int4_fp8_gemm.cu + TEST_COMMAND_OPTIONS + TEST_DIRECT_BATCHED + TEST_SCALE_PERCOL + TEST_SCALE_GROUP + TEST_SCALE_RESIDUE + # TEST_ALPHA_BETA + ) + + cutlass_example_add_executable( + 55_hopper_int4_bf16_gemm + 55_hopper_int4_bf16_gemm.cu + TEST_COMMAND_OPTIONS + TEST_DIRECT_BATCHED + TEST_SCALE_PERCOL + TEST_SCALE_GROUP + TEST_SCALE_RESIDUE + # TEST_ALPHA_BETA + ) diff --git a/examples/55_hopper_mixed_dtype_gemm/README.md b/examples/55_hopper_mixed_dtype_gemm/README.md new file mode 100644 index 0000000000..ecb4f41c97 --- /dev/null +++ b/examples/55_hopper_mixed_dtype_gemm/README.md @@ -0,0 +1,44 @@ +This example shows how to do mixed types GEMMs in CUTLASS. + +## High level overview +This example shows how to perform GEMMs on Hopper when A and B have different types. This implementation always passes the type with fewer bits through the register file and upcasts to the type with the higher bit count. + +When relying on `KernelScheduleAuto`, the main loop supporting different A and B types will be selected whenever the bit count of A is not equal to the bit count of B. Users can manually select the mixed type main loop and explicitly choose the scheduling policy by specifying one of the following schedules to the `CollectiveBuilder`: `KernelTmaWarpSpecialized`, `KernelTmaWarpSpecializedPingpong` or `KernelTmaWarpSpecializedCooperative`. + +This first version only supports mixed type GEMMs using TMA. + +## Performance + +While the example offers a harness for straightforward benchmarking, this initial implementation isn't optimized for performance in the majority of scenarios. 
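As a reference for the schedule selection described in the overview above, the fragment below sketches how a mixed-dtype mainloop can be composed with the `CollectiveBuilder` when an explicit schedule is chosen. All type names, alignments and tile sizes in it (`MmaType`, `QuantType`, `ElementScale`, `ElementZero`, and so on) are illustrative assumptions, not the exact definitions used by `55_hopper_mixed_dtype_gemm.cu`.

```c++
// Minimal sketch, assuming illustrative type names; the example's own code additionally carves
// shared-memory stages out around its epilogue (StageCountAutoCarveout) and builds a full kernel.
#include "cutlass/cutlass.h"
#include "cutlass/numeric_types.h"
#include "cutlass/layout/matrix.h"
#include "cutlass/gemm/dispatch_policy.hpp"
#include "cutlass/gemm/collective/collective_builder.hpp"
#include "cute/tensor.hpp"

using MmaType   = cutlass::half_t;   // wide type, fed to the MMA
using QuantType = cutlass::int4b_t;  // narrow type, upcast through the register file

using ElementScale = cutlass::half_t;
using ElementZero  = cutlass::half_t;

// The narrow type is the B operand here, so A and B are swapped and their layouts transposed;
// after the swap the narrow operand occupies the first slot and is K-major.
using LayoutA_Transpose = cutlass::layout::ColumnMajor;
using LayoutB_Transpose = cutlass::layout::RowMajor;

constexpr int AlignmentA = 128 / cutlass::sizeof_bits<MmaType>::value;    // 8 elements
constexpr int AlignmentB = 128 / cutlass::sizeof_bits<QuantType>::value;  // 32 elements

using TileShape    = cute::Shape<cute::_128, cute::_128, cute::_64>;
using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;

// One of the mixed-input schedules named above, chosen explicitly instead of KernelScheduleAuto.
using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized;

// A tuple of {narrow element, scale, zero} selects scale-with-zero-point mode;
// {narrow element, scale} selects scale-only; the narrow element alone selects direct conversion.
using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    cute::tuple<QuantType, ElementScale, ElementZero>, LayoutB_Transpose, AlignmentB,
    MmaType, LayoutA_Transpose, AlignmentA,
    float,                                     // accumulator type
    TileShape, ClusterShape,
    cutlass::gemm::collective::StageCountAuto,
    KernelSchedule
  >::CollectiveOp;
```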
We expect this implementation to be performant for `{fp16, bf16} x {int8, int4, int2}` and `{fp8} x {int4}` for problems that are compute bound. Additionally, we expect good performance for `fp16`, `bf16` or `fp32` scales and zero-points. For best performance, it is ideal to have the scales and zero-points be the same type as mma's type. + +The scale only mode for `fp8 x int4` is significantly slower than direct conversion mode. There is a lookup-table workaround targeting this mode, as shown in `55_hopper_int4_fp8_gemm.cu`. To use this feature, use `cutlass::Array` as the scale type in the collective builder. However, it requires modifications to the encoding of quantized weights and scale factors. Also, scale with zero point mode is not supported for now. + + +Additionally, it's recommended to reorder the narrow data type tensor such that elements read into register file by the same thread are contiguous in global and shared memory. The user can use the helper function `compute_memory_reordering_atom` and `reorder_tensor` to achieve this. See `55_hopper_int4_fp8_gemm.cu` and `55_hopper_int4_bf16_gemm.cu` for more details. + + +We are currently optimizing the following cases: +1. Memory bound cases for all types +2. `fp8 x {int2, uint2}` case + +## Limitations + +* The type that needs to be converted must go through the register file. This means that the collective will swap and transpose whenever the type with fewer bits is the B operand. The user must be aware of when these swaps happen. Note that TMA epilogues currently do not support *implicit* swap + transpose, so non-tma epilogues must be used in this case. We plan to relax this limitation in a future release. + +* The layout of the narrow type must be K-major. This means the following: + * Narrow type is the A operand: Must be Row-Major + * Narrow type is the B operand: Must be Column-Major + +* For 8-bit x 4-bit or 2-bit, both inputs must be K-major. + +* TMA requires an alignment of 128 bits. As a result, for a type with `B` bits, `B x TILE_K` must be a multiple of 128 bits. + +* The type of the scale and zero-point type must be two bytes or more. + +* The group size must be equal to gemm-k size (indicating a broadcast), or it must be a multiple of the threadblock-k size. + +## Upcoming features + +* Optimizations for memory bound cases. + +* Optimizations for scale and zero-point loading when the group size is not equal to the threadblock-k size. diff --git a/examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp b/examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp new file mode 100644 index 0000000000..55de3fabb3 --- /dev/null +++ b/examples/55_hopper_mixed_dtype_gemm/mixed_dtype_utils.hpp @@ -0,0 +1,391 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/util/command_line.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/device/tensor_compare.h" + +#include "cute/tensor.hpp" + +#include +#include +#include "helper.h" + +enum MixedDtypeGemmMode { + ConvertOnly, + ScaleOnly, + ScaleWithZeroPoint +}; + +/// Command line options parsing +struct MixedDtypeOptions { + + bool help = false; + + float alpha = 1.0f; + float beta = 0.0f; + int iterations = 1000; + int warmup = 1000; + int mode = 1; + int m = 5120, n = 4096, k = 4096; + int g = 128; + int l = 1; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("g", g); + cmd.get_cmd_line_argument("mode", mode); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("warmup", warmup); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "55_hopper_mixed_dtype_gemm\n\n" + << " Hopper Mixed Data Type GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= The number of independent gemm problems with mnk shape\n" + << " --g= The size of each group for the scales and zeros. To broadcast a vector of scales or zeros, set the group size to K.\n" + << " --mode= The mode to run the gemm. 
0 does (A @ B), 1 means A @ (scale * B), 2 means A @ (scale * B + zero-point).\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n" + << " --warmup= Number of warmup iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "55_hopper_mixed_dtype_gemm" << " --m=1024 --n=512 --k=1024 -g=1024 --l=10 --alpha=2 --mode=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k * l; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct MixedDtypeResult +{ + double avg_runtime_ms = 0.0; + double gflops = 0.0; + cutlass::Status status = cutlass::Status::kSuccess; + cudaError_t error = cudaSuccess; + bool passed = false; + +}; + +/// Profiling Loop +template +void mixed_dtype_profiling( + Gemm& gemm, + MixedDtypeOptions const& options, + MixedDtypeResult& result) { + + if (options.iterations <= 0) return; + + cudaEvent_t start, stop; + cudaEventCreate(&start); + cudaEventCreate(&stop); + + std::vector runtimes; + runtimes.reserve(options.iterations); + + for (int iter = 0; iter < options.warmup + options.iterations; ++iter) { + cudaEventRecord(start); + CUTLASS_CHECK(gemm.run()); + cudaEventRecord(stop); + cudaEventSynchronize(stop); + + if (iter >= options.warmup) { + float milliseconds = 0; + cudaEventElapsedTime(&milliseconds, start, stop); + runtimes.push_back(milliseconds); + } + } + + cudaEventDestroy(start); + cudaEventDestroy(stop); + + // Compute average setup and runtime and GFLOPs. + result.avg_runtime_ms = std::accumulate(runtimes.begin(), runtimes.end(), 0.0f) / runtimes.size(); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + +} + +/// Helpers to initialize a block of device data +template +bool initialize_tensor( + cutlass::DeviceAllocation& block, + uint64_t seed = 2023) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } + else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + + return true; +} + +template +bool initialize_quant_tensor( + cutlass::DeviceAllocation& block, + uint64_t seed = 2023) { + + float scope_min = float(cutlass::platform::numeric_limits::lowest()); + float scope_max = float(cutlass::platform::numeric_limits::max()); + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + + return true; +} + +template +bool initialize_scale( + cutlass::DeviceAllocation& block, + MixedDtypeOptions const& options, + uint64_t seed = 2023) { + + if (options.mode == MixedDtypeGemmMode::ConvertOnly) { + // No scales, so just initialize with 1 so we can use the same kernel to dequantize the data. 
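+    // Note: the reference dequantizer below computes B_dq = scale * B + zero, so a scale of 1
+    // (together with a zero of 0 from initialize_zero) makes dequantization a plain type conversion.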
+ std::vector stage(block.size(), Element(1.0f)); + block.copy_from_host(stage.data()); + } + else { + float elt_max_f = float(cutlass::platform::numeric_limits::max()); + const float max_dequant_val = 4.f; + const float min_dequant_val = 0.5f; + + float scope_max(max_dequant_val / elt_max_f); + float scope_min(min_dequant_val / elt_max_f); + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(scope_max), Element(scope_min)); + } + return true; +} + +template +bool initialize_zero( + cutlass::DeviceAllocation& block, + MixedDtypeOptions const& options, + uint64_t seed = 2023) { + + if (options.mode == MixedDtypeGemmMode::ScaleWithZeroPoint) { + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, Element(2.0f), Element(-2.0f)); + } else { + // No bias, so just initialize with 1 so we can use the same kernel to dequantize the data. + std::vector stage(block.size(), Element(0.0f)); + block.copy_from_host(stage.data()); + } + return true; +} + +/// Dequantize the weights for verification + +template +__global__ void dequantize_weight_kernel(DequantizedElement* dq_buffer, + QuantizedElement const* q_buffer, + OperandLayout const operand_layout, + ElementScale const* scale_buffer, + ElementZero const* zero_buffer, + ScaleBroadCastLayout const broadcasted_scale_layout, + ThrLayout thr_layout) { + using namespace cute; + + // Represent the full tensors to gmem elements. + // These are expected to have shape [MN, K, L] + cute::Tensor gmem_op_dq = cute::make_tensor(cute::make_gmem_ptr(dq_buffer), operand_layout); + auto init_quantized_iterator = [&]() { + if constexpr (cute::sizeof_bits_v >= 8) { + return cute::make_gmem_ptr(q_buffer); + } else { + return cute::subbyte_iterator(q_buffer); + } + }; + cute::Tensor gmem_op_q = cute::make_tensor(init_quantized_iterator(), operand_layout); + // While the scales are expected to have shape [MN, G, L] but with a stride to allow broadcasting + // It is expected that K % G == 0 + cute::Tensor gmem_scale_broadcasted = cute::make_tensor(make_gmem_ptr(scale_buffer), broadcasted_scale_layout); + cute::Tensor gmem_zero_broadcasted = cute::make_tensor(make_gmem_ptr(zero_buffer), broadcasted_scale_layout); + + // Assign 1 thread per element in the thread block + auto blk_shape = make_shape(size<0>(thr_layout), _1{}, _1{}); // + auto blk_coord = make_coord(_, blockIdx.x, blockIdx.y); // (MN, K, L) + + // Tile across the block + auto gOp_dq = cute::local_tile(gmem_op_dq, blk_shape, blk_coord); + auto gScale = cute::local_tile(gmem_scale_broadcasted, blk_shape, blk_coord); + auto gZero = cute::local_tile(gmem_zero_broadcasted, blk_shape, blk_coord); + auto gOp_q = cute::local_tile(gmem_op_q, blk_shape, blk_coord); + + auto tOpDq_gOpDq = cute::local_partition(gOp_dq, thr_layout, threadIdx.x); + auto tScale_gScale = cute::local_partition(gScale, thr_layout, threadIdx.x); + auto tZero_gZero = cute::local_partition(gZero, thr_layout, threadIdx.x); + auto tOpQ_gOpQ = cute::local_partition(gOp_q, thr_layout, threadIdx.x); + + // Make a fragment of registers to hold gmem loads + cute::Tensor rmem_op_q = cute::make_fragment_like(tOpQ_gOpQ(_, _, _, 0)); + cute::Tensor rmem_scale = cute::make_fragment_like(tScale_gScale(_, _, _, 0)); + cute::Tensor rmem_zero = cute::make_fragment_like(tZero_gZero(_, _, _, 0)); + cute::Tensor rmem_op_dq = cute::make_fragment_like(tOpDq_gOpDq(_, _, _, 0)); + cute::Tensor rmem_op_scaled = cute::make_fragment_like(rmem_op_dq); + cute::Tensor rmem_zero_buf = 
cute::make_fragment_like(rmem_zero); + + cute::Tensor pred_id = cute::make_identity_tensor(shape(operand_layout)); + auto pred_blk_tile = cute::local_tile(pred_id, blk_shape, blk_coord); + auto pred_thr_partition = cute::local_partition(pred_blk_tile, thr_layout, threadIdx.x); + + const auto num_iters = cute::size<3>(tOpDq_gOpDq); + + for (int ii = 0; ii < num_iters; ++ii) { + const auto thread_offset = cute::get<0>(pred_thr_partition(0, 0, 0, ii)); + if (thread_offset < cute::size<0>(operand_layout)) { + cute::copy(tOpQ_gOpQ(_, _, _, ii), rmem_op_q); + cute::copy(tScale_gScale(_, _, _, ii), rmem_scale); + cute::copy(tZero_gZero(_, _, _, ii), rmem_zero); + cute::transform(rmem_op_q, rmem_op_scaled, [] (const QuantizedElement& elt) { return ElementScale(elt); } ); + cute::transform(rmem_zero, rmem_zero_buf, [] (const ElementZero& elt) { return ElementScale(elt); } ); + cute::transform(rmem_op_scaled, rmem_scale, rmem_op_scaled, multiplies{}); + cute::transform(rmem_op_scaled, rmem_zero_buf, rmem_op_scaled, plus{}); + cute::transform(rmem_op_scaled, rmem_op_dq, [] (const ElementScale& elt) { return DequantizedElement(elt); } ); + cute::copy(rmem_op_dq, tOpDq_gOpDq(_, _, _, ii)); + } + } +} + +template +void dequantize_weight(DequantizedElement* dq_buffer, + QuantizedElement const* q_buffer, + OperandLayout const operand_layout, + ElementScale const* scale_buffer, + ElementZero const* zero_buffer, + ScaleLayout const scale_layout, + int const group_size) { + + using namespace cute; + + constexpr int tpb = 128; + auto thr_layout = make_layout(make_shape(Int{})); + + const auto num_rows = get<0>(shape(operand_layout)); + const auto gemm_k = get<1>(shape(operand_layout)); // [MN, K, L] + const auto batches = get<2>(shape(operand_layout)); // [MN, K, L] + const auto scale_k = get<1>(shape(scale_layout)); // [MN, Scale_K, L] + + if (num_rows != size<0>(scale_layout)) { + std::cerr << "Invalid first dimension for scales. Must match first dim for weights." + << " But got shapes " << shape(operand_layout) << " " << shape(scale_layout) + << std::endl; + exit(-1); + } + + const auto scale_stride0 = get<0>(stride(scale_layout)); + const auto scale_stride1 = get<1>(stride(scale_layout)); + const auto scale_stride2 = get<2>(stride(scale_layout)); + + auto scale_shape_bcast = make_shape(num_rows, make_shape(group_size, scale_k), batches); + auto scale_stride_bcast = make_stride(scale_stride0, make_stride(0, scale_stride1), scale_stride2); + auto scale_layout_bcast = make_layout(scale_shape_bcast, scale_stride_bcast); + + const auto blocks_x = gemm_k; + const auto blocks_y = batches; + + dim3 blocks(blocks_x, blocks_y, 1); + dequantize_weight_kernel<<>>(dq_buffer, q_buffer, operand_layout, scale_buffer, zero_buffer, scale_layout_bcast, thr_layout); + CUDA_CHECK(cudaDeviceSynchronize()); +} diff --git a/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp b/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp new file mode 100644 index 0000000000..bd71e9cf28 --- /dev/null +++ b/examples/55_hopper_mixed_dtype_gemm/packed_scale.hpp @@ -0,0 +1,212 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include + + +#include "cutlass/util/device_memory.h" +#include "cutlass/integer_subbyte.h" +#include "cutlass/float8.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "cute/tensor.hpp" +#include "cute/util/type_traits.hpp" + +namespace cutlass +{ +template +class packed_scale_t { +public: + static_assert(cute::is_same_v || + cute::is_same_v || + cute::is_same_v || + cute::is_same_v, + "only 8 bit arithmetic types are supported."); + CUTLASS_HOST_DEVICE + explicit packed_scale_t(T val) { + if constexpr (!cute::is_unsigned_v) { + // Only pack negative values. The positive values are generated in flight in the mainloop. 
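+    // The two 32-bit words form an 8-entry byte lookup table {-8*val, -7*val, ..., -1*val};
+    // the mainloop's PRMT picks one byte per 4-bit value and applies the sign with a second PRMT.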
+ storage[0] = pack4(T(float(val) * -8.f), T(float(val) * -7.f), T(float(val) * -6.f), T(float(val) * -5.f)); + storage[1] = pack4(T(float(val) * -4.f), T(float(val) * -3.f), T(float(val) * -2.f), -val); + } + else { + storage[0] = pack4(T(float(val) * 8.f), T(float(val) * 7.f), T(float(val) * 6.f), T(float(val) * 5.f)); + storage[1] = pack4(T(float(val) * 4.f), T(float(val) * 3.f), T(float(val) * 2.f), val); + } + } + CUTLASS_HOST_DEVICE + packed_scale_t() = default; + CUTLASS_HOST_DEVICE + explicit operator float() const { + return float(get()); + } + CUTLASS_HOST_DEVICE + bool operator==(packed_scale_t const& rhs) const { + return storage[0] == rhs.storage[0] && storage[1] == rhs.storage[1]; + } + CUTLASS_HOST_DEVICE + bool operator!=(packed_scale_t const& rhs) const { + return !(*this == rhs); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator+(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() + rhs.get()); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator-(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() - rhs.get()); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator*(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() * rhs.get()); + } + CUTLASS_HOST_DEVICE + friend packed_scale_t operator/(packed_scale_t const& lhs, packed_scale_t const& rhs) { + return packed_scale_t(lhs.get() / rhs.get()); + } + +private: + using Storage = uint32_t; + using Stage = uint8_t; + + Storage storage[2] {}; + + CUTLASS_HOST_DEVICE + static Storage pack4(T c1, T c2, T c3, T c4) { + Storage result = 0; + result |= (static_cast(reinterpret_cast(c4)) << 24); + result |= (static_cast(reinterpret_cast(c3)) << 16); + result |= (static_cast(reinterpret_cast(c2)) << 8); + result |= static_cast(reinterpret_cast(c1)); + return result; + } + CUTLASS_HOST_DEVICE + T get() const { + auto stage = static_cast(storage[0] >> 8); + #if defined(__CUDA_ARCH__) + return reinterpret_cast(stage); + #else + T tmp; + std::memcpy(&tmp, &stage, sizeof(Stage)); + return tmp; + #endif + } + CUTLASS_HOST_DEVICE + T get(int idx) const { + Stage stage; + if (idx < 4) stage = static_cast(storage[0] >> (8 * idx)); + else stage = static_cast(storage[1] >> (8 * idx - 32)); + #if defined(__CUDA_ARCH__) + return reinterpret_cast(stage); + #else + T tmp; + std::memcpy(&tmp, &stage, sizeof(Stage)); + return tmp; + #endif + } +}; +} + +/// Helpers to initialize scale lookup table + +// In the mainloop, PRMT selects 1 byte from only 8 bytes so the sign bit is handled in an extra PRMT. +// Here the encodings of positive values and negative values are unified (except for the sign bit). +// For instance, 1 becomes 0b0111, which is the same encoding as -1 (0b1111). 
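+// Concretely, a positive magnitude k in [1, 7] is re-encoded as (8 - k) with the sign bit clear,
+// so it shares its low three bits with the two's-complement encoding of -k
+// (e.g. +3 = 0b0011 becomes 0b0101, matching the low bits of -3 = 0b1101).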
+bool unify_quant_encoding( + cutlass::DeviceAllocation const& block_in, + cutlass::DeviceAllocation& block_out) { + + using StorageType = cutlass::int4b_t::Storage; + + if (block_in.size() != block_out.size()) { + std::cerr << "block_in and block_out must have same size.\n"; + return false; + } + constexpr int pack = cute::sizeof_bits_v / 4; + std::vector data(block_in.size() / pack); + cutlass::device_memory::copy_to_host(data.data(), (StorageType*)block_in.get(), block_in.size() / pack); + + for (auto&& d : data) { + StorageType out = 0; + StorageType mask = 0x0f; + for (int i = 0; i < pack; ++i) { + cutlass::int4b_t curr; + curr.storage = (d >> (i * 4)) & 0x0f; + switch (curr) { + case 1: curr.storage = StorageType(0b0111); break; // 2's complement + case 2: curr.storage = StorageType(0b0110); break; // 2's complement + case 3: curr.storage = StorageType(0b0101); break; // 2's complement + case 4: curr.storage = StorageType(0b0100); break; // 2's complement + case 5: curr.storage = StorageType(0b0011); break; // 2's complement + case 6: curr.storage = StorageType(0b0010); break; // 2's complement + case 7: curr.storage = StorageType(0b0001); break; // 2's complement + default: break; + } + out |= (curr.storage << (4 * i)) & mask; + mask <<= 4; + } + d = out; + } + + cutlass::device_memory::copy_to_device((StorageType*)block_out.get(), data.data(), block_out.size() / pack); + return true; +} + +template +bool initialize_packed_scale( + cutlass::DeviceAllocation const& block_in, + cutlass::DeviceAllocation > & block_out) { + + std::vector data_in(block_in.size()); + std::vector > data_out(block_in.size()); + try { + block_in.copy_to_host(data_in.data()); + } catch (cutlass::cuda_exception const& e) + { + std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl; + return false; + } + for (size_t i = 0; i < block_in.size(); ++i) + { + cutlass::packed_scale_t tmp(data_in[i]); + data_out[i] = reinterpret_cast const&>(tmp); + } + try { + block_out.copy_from_host(data_out.data()); + } catch (cutlass::cuda_exception const& e) + { + std::cerr << "CUDA Error: " << cudaGetErrorString(e.cudaError()) << std::endl; + return false; + } + return true; +} diff --git a/examples/55_hopper_mixed_dtype_gemm/reorder_utils.hpp b/examples/55_hopper_mixed_dtype_gemm/reorder_utils.hpp new file mode 100644 index 0000000000..de5a3d3fd0 --- /dev/null +++ b/examples/55_hopper_mixed_dtype_gemm/reorder_utils.hpp @@ -0,0 +1,162 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cute/arch/mma_sm90.hpp" + +#include "cutlass/util/device_memory.h" + +// Given a type of MMA instruction, compute a memory reordering atom that places all values +// owned by each thread in contiguous memory locations. This improves smem load vectorization, +// particularly for mixed dtype GEMMs where a narrow type is loaded in the thread/value order +// of the wider type and may result in inefficient sub-bank (8-bit or 16-bit) accesses. +// In addition, we can reorder the values across several MMA instructions to get even wider +// vectorization (AtomLayout parameter) and permute the values within each instruction to get +// more optimal conversion instruction sequences (ValLayout parameter). +template, + class ValLayout = cute::Layout> +constexpr auto compute_memory_reordering_atom(AtomLayout atom_layout = {}, ValLayout val_layout = {}) +{ + using namespace cute; + + static_assert(is_static_v, "ValLayout must be static"); + static_assert(is_static_v, "AtomLayout must be static"); + + // 1. Choose an MMA atom to access TV layout and MN shape + // Note: parameters like GMMA Major, TileShape, ElementC don't affect TV layout of A, use arbitrary + using MmaAtom = decltype(SM90::GMMA::rs_op_selector>()); + using MmaTraits = MMA_Traits; + auto mk_shape_mma = select<0,2>(typename MmaTraits::Shape_MNK{}); + auto tv_layout_mma = typename MmaTraits::ALayout{}; + static_assert(size<1>(tv_layout_mma) % size(val_layout) == 0, "Value layout must evenly divide the MMA value layout"); + + // 2. Create a single warp's TV layout from that of the whole MMA and invert to get (m,k -> thr,val) + // Note: this assumes A is partitioned between warps along M mode + auto tv_tiler_warp = make_shape(Int<32>{}, size<1>(tv_layout_mma)); + auto mk_shape_warp = shape_div(mk_shape_mma, size(typename MmaTraits::ThrID{}) / Int<32>{}); + auto tv_layout_mma_warp = make_layout_like(composition(tv_layout_mma, tv_tiler_warp)); + auto mk_layout_mma_warp = right_inverse(tv_layout_mma_warp).with_shape(mk_shape_warp); + + // 3. Repeat the warp layout NumAtoms times along K mode to get wider vectorization + auto mk_layout_mma_trgt = blocked_product(mk_layout_mma_warp, atom_layout); + + // 4. 
Compose with a contiguous layout of values in each thread (required for smem vectorization) + auto val_to_offset = logical_product(val_layout, size<1>(tv_layout_mma) / size(val_layout) * size(atom_layout)); + auto thr_to_offset = make_layout(size<0>(tv_layout_mma_warp)); + auto tv_to_offset = select<1,0>(logical_product(val_to_offset, thr_to_offset)); + auto layout_atom = composition(tv_to_offset, mk_layout_mma_trgt); + + return layout_atom; +} + +template +__global__ void reorder_tensor_kernel( + cute::Tensor S, + cute::Tensor D, + TiledCopy tiled_copy) +{ + using namespace cute; + + using T = typename EngineDst::value_type; + + Tensor gS = local_tile(S, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z)); + Tensor gD = local_tile(D, TileShape{}, make_coord(blockIdx.x, _, blockIdx.z)); + + auto thread_copy = tiled_copy.get_slice(threadIdx.x); + Tensor tS = thread_copy.partition_S(gS); + Tensor tD = thread_copy.partition_D(gD); + + copy(tiled_copy, tS, tD); +} + +template +void reorder_tensor( + cute::Tensor S, + cute::Tensor D) +{ + using namespace cute; + + using T = typename EngineDst::value_type; + static_assert(is_same_v, T>, "Type mismatch"); + + // Construct a value layout that assigns at least 8 bits of contiguous elements in destination tensor to a thread + // This avoids a race condition when writing out subbyte types (e.g. int4b_t). + auto has_major_mode = [](auto s) { + return any_of(s, [](auto a){ return is_constant<1, decltype(a)>{}; }); + }; + static_assert(has_major_mode(stride<0>(LayoutDst{})) ^ has_major_mode(stride<1>(LayoutDst{})), + "Could not find stride-1 mode in destination layout"); + constexpr int N = shape_div(Int<8>{}, sizeof_bits{}); + auto val_layout = conditional_return(LayoutDst{}))>( + make_layout(make_shape(Int{}, Int<1>{}), GenColMajor{}), + make_layout(make_shape(Int<1>{}, Int{}), GenRowMajor{})); + + // Make a tiled copy with a simple row-major thread order and above layout + int constexpr NumThreads = 128; + auto const thr_layout = make_layout(make_shape(Int<1>{}, Int{})); + auto tiled_copy = make_tiled_copy(Copy_Atom{}, thr_layout, val_layout); + + // Assign a group of 16 rows to a threadblock; this matches the shuffle atom size for Hopper + using TileShape = Shape<_16>; + auto tiled_D = group_modes<3,rank_v>(tiled_divide(D, TileShape{})); + dim3 blocks{unsigned(size<1>(tiled_D)), 1u, unsigned(size<3>(tiled_D))}; + + reorder_tensor_kernel<<>>(S, D, tiled_copy); + CUDA_CHECK(cudaDeviceSynchronize()); +} + +// In-place version +template +void reorder_tensor( + T const* src, + LayoutSrc const& layout_src, + T * dst, + LayoutDst const& layout_dst) +{ + using namespace cute; + reorder_tensor(make_tensor(make_gmem_ptr(src), layout_src), + make_tensor(make_gmem_ptr(dst), layout_dst)); +} + +// In-place version +template +void reorder_tensor( + T * data, + LayoutSrc const& layout_src, + LayoutDst const& layout_dst) +{ + using namespace cute; + cutlass::DeviceAllocation temp(size(layout_src)); + reorder_tensor(data, layout_src, temp.get(), layout_dst); + cutlass::device_memory::copy_device_to_device(data, temp.get(), static_cast(size(layout_src))); +} diff --git a/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu b/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu new file mode 100644 index 0000000000..51ce970dbd --- /dev/null +++ b/examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm.cu @@ -0,0 +1,550 @@ 
+/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Hopper Ptr-Array Batched GEMM example using CUTLASS 3 APIs for NVIDIA Hopper architecture. + + This example demonstrates an implementation of Ptr-Array Batched GEMM using a TMA + GMMA + warp-specialized cooperative kernel. + The new feature showcased in this example is on-the-fly modification of TMA descriptors + to move between batches (represented by l). 
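+    Rather than a single base pointer plus a batch stride, each batch's A/B/C/D operands are
+    addressed through device arrays of pointers (one entry per batch), and the GEMM runs in
+    cutlass::gemm::GemmUniversalMode::kArray.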
+ + To run this example: + + $ ./examples/56_hopper_ptr_array_batched_gemm/56_hopper_ptr_array_batched_gemm --m=2048 --n=2048 --k=2048 --l=10 +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = cutlass::half_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size + +// Different configs for pingpong/cooperative +struct CooperativeConfig { + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperative; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; + using TileShape = Shape<_256,_128,_64>; + using ClusterShape = Shape<_1,_2,_1>; +}; + +struct PingpongConfig { + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpong; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = Shape<_64,_128,_64>; + using ClusterShape = Shape<_1,_1,_1>; +}; + +template +struct GemmGivenSchedule { + using TileShape = typename ScheduleConfig::TileShape; // 
Threadblock-level tile size + using ClusterShape = typename ScheduleConfig::ClusterShape; // Shape of the threadblocks in a cluster + using KernelSchedule = typename ScheduleConfig::KernelSchedule; // Kernel to launch + using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; // Epilogue to launch + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC, AlignmentC, + ElementC, LayoutC, AlignmentC, + EpilogueSchedule + >::CollectiveOp; + + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + cutlass::gemm::ArrayProblemShape>, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +using GemmKernel = GemmGivenSchedule::GemmKernel; +using Gemm = GemmGivenSchedule::Gemm; + +using GemmKernelPingpong = GemmGivenSchedule::GemmKernel; +using GemmPingpong = GemmGivenSchedule::Gemm; + + +// Reference device GEMM implementation type +using DeviceGemmReference = cutlass::reference::device::Gemm< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + ElementAccumulator>; + +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; +uint64_t seed; + +std::vector offset_A; +std::vector offset_B; +std::vector offset_C; +std::vector offset_D; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; + +cutlass::DeviceAllocation ptr_A; +cutlass::DeviceAllocation ptr_B; +cutlass::DeviceAllocation ptr_C; +cutlass::DeviceAllocation ptr_D; +cutlass::DeviceAllocation ptr_ref_D; + +#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help = false; + + float alpha = 1.0f; + float beta = 0.0f; + int iterations = 10; + int m = 1024, n = 512, k = 1024, l = 10; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. 
+ std::ostream & print_usage(std::ostream &out) const { + + out << "56_hopper_ptr_array_batched_gemm\n\n" + << " Hopper FP32 GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the batch count for Ptr-Array GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "56_hopper_ptr_array_batched_gemm" << " --m=1024 --n=512 --k=1024 --l=10 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k * l; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms = 0.0; + double gflops = 0.0; + cutlass::Status status = cutlass::Status::kSuccess; + cudaError_t error = cudaSuccess; + bool passed = false; +}; + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = static_cast(2); + scope_min = static_cast(0); + } else if (bits_input <= 8) { + scope_max = static_cast(2); + scope_min = static_cast(-2); + } else { + scope_max = static_cast(8); + scope_min = static_cast(-8); + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + + return true; +} + +/// Allocates device-side data +void allocate(const Options &options) { + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + + for (int32_t i = 0; i < options.l; ++i) { + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + + int64_t elements_A = options.m * options.k; + int64_t elements_B = options.k * options.n; + int64_t elements_C = options.m * options.n; + int64_t elements_D = options.m * options.n; + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + } + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + block_ref_D.reset(total_elements_D); +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, 
cute::make_shape(options.m, options.n, options.l)); + + // + // Assign pointers + // + + std::vector ptr_A_host(options.l); + std::vector ptr_B_host(options.l); + std::vector ptr_C_host(options.l); + std::vector ptr_D_host(options.l); + + for (int32_t i = 0; i < options.l; ++i) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + } + + ptr_A.reset(options.l); + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(options.l); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_C.reset(options.l); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(options.l); + ptr_D.copy_from_host(ptr_D_host.data()); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +template +typename GemmT::Arguments args_from_options(const Options &options) +{ + cutlass::KernelHardwareInfo hw_info; + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + typename GemmT::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kArray, + {{options.m, options.n, options.k, options.l}}, + {ptr_A.get(), stride_A, ptr_B.get(), stride_B}, + {{options.alpha, options.beta}, ptr_C.get(), stride_C, ptr_D.get(), stride_D}, + hw_info + }; + + return arguments; +} + +bool verify(const Options &options) { + bool passed = true; + for (int32_t i = 0; i < options.l; ++i) { + cutlass::TensorRef ref_A(block_A.get() + offset_A.at(i), Gemm::LayoutA::packed({options.m, options.k})); + cutlass::TensorRef ref_B(block_B.get() + offset_B.at(i), Gemm::LayoutB::packed({options.k, options.n})); + cutlass::TensorRef ref_C(block_C.get() + offset_C.at(i), Gemm::LayoutC::packed({options.m, options.n})); + cutlass::TensorRef ref_D(block_ref_D.get() + offset_D.at(i), Gemm::LayoutD::packed({options.m, options.n})); + + // + // Compute reference output + // + + // Create instantiation for device reference gemm kernel + DeviceGemmReference gemm_reference; + + // Launch device reference gemm kernel + gemm_reference( + {options.m, options.n, options.k}, + ElementAccumulator(options.alpha), + ref_A, + ref_B, + ElementAccumulator(options.beta), + ref_C, + ref_D); + + // Wait for kernel to finish + CUDA_CHECK(cudaDeviceSynchronize()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + passed &= cutlass::reference::device::BlockCompareEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), options.m * options.n); + } + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + allocate(options); + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + GemmT gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = GemmT::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is 
supported or not
+  CUTLASS_CHECK(gemm.can_implement(arguments));
+
+  // Initialize CUTLASS kernel with arguments and workspace pointer
+  CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
+
+  // Correctness / Warmup iteration
+  CUTLASS_CHECK(gemm.run());
+
+  // Check if output from CUTLASS kernel and reference kernel are equal or not
+  Result result;
+  result.passed = verify(options);
+
+  std::cout << "  Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl;
+
+  if (!result.passed) {
+    exit(-1);
+  }
+
+  // Run profiling loop
+  if (options.iterations > 0)
+  {
+    GpuTimer timer;
+    timer.start();
+    for (int iter = 0; iter < options.iterations; ++iter) {
+      CUTLASS_CHECK(gemm.initialize(arguments, workspace.get()));
+      CUTLASS_CHECK(gemm.run());
+    }
+    timer.stop();
+
+    // Compute average setup and runtime and GFLOPs.
+    float elapsed_ms = timer.elapsed_millis();
+    result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations);
+    result.gflops = options.gflops(result.avg_runtime_ms / 1000.0);
+
+    std::cout << "  Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl;
+    std::cout << "  Batches     : " << options.l << std::endl;
+    std::cout << "  Alpha, Beta : " << options.alpha << ',' << options.beta << std::endl;
+    std::cout << "  Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl;
+    std::cout << "  GFLOPS      : " << result.gflops << std::endl;
+  }
+
+  return 0;
+}
+
+#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
+
+///////////////////////////////////////////////////////////////////////////////////////////////////
+
+int main(int argc, char const **args) {
+
+  // CUTLASS must be compiled with CUDA 12.3 Toolkit to run this example
+  if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) {
+    std::cerr << "This example requires CUDA 12.3 or newer.\n";
+    // Returning zero so this test passes on older Toolkits. Its actions are no-op.
+    return 0;
+  }
+
+  cudaDeviceProp props;
+  int current_device_id;
+  CUDA_CHECK(cudaGetDevice(&current_device_id));
+  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
+  cudaError_t error = cudaGetDeviceProperties(&props, 0);
+  if (props.major < 9) {
+    std::cerr
+      << "This example requires a GPU of NVIDIA's Hopper Architecture or "
+      << "later (compute capability 90 or greater).\n";
+    return 0;
+  }
+  //
+  // Parse options
+  //
+
+  Options options;
+
+  options.parse(argc, args);
+
+  if (options.help) {
+    options.print_usage(std::cout) << std::endl;
+    return 0;
+  }
+
+  //
+  // Evaluate CUTLASS kernels
+  //
+
+#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
+  std::cout << "\n*** Cooperative schedule ***" << std::endl;
+  run<Gemm>(options);
+  std::cout << "\n*** Pingpong schedule ***" << std::endl;
+  run<GemmPingpong>(options);
+#endif
+
+  return 0;
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
diff --git a/examples/56_hopper_ptr_array_batched_gemm/CMakeLists.txt b/examples/56_hopper_ptr_array_batched_gemm/CMakeLists.txt
new file mode 100644
index 0000000000..1f59ceb8a1
--- /dev/null
+++ b/examples/56_hopper_ptr_array_batched_gemm/CMakeLists.txt
@@ -0,0 +1,54 @@
+
+# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. 
Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +set(TEST_SQUARE --m=2048 --n=2048 --k=2048 -l=10 --iterations=1) # Square problem sizes +set(TEST_SQUARE_LARGE_BATCH --m=2048 --n=2048 --k=2048 -l=500 --iterations=1) # Square problem sizes + +set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=1) # Default problem sizes +set(TEST_EPILOGUE_LARGE_BATCH --alpha=1.5 --beta=2.0 -l=500 --iterations=1) # Default problem sizes + +set(TEST_EPILOGUE_OP --beta=0.5 --iterations=1) # Default problem sizes w/ Epilogue Op test +set(TEST_EPILOGUE_OP_LARGE_BATCH --alpha=1.5 -l=500 --iterations=1) # Default problem sizes w/ Epilogue Op test + +set(TEST_SMALLK --m=2048 --n=5120 --k=128 --l=5 --iterations=1) # Small-k problem sizes +set(TEST_SMALLK_LARGE_BATCH --m=1024 --n=512 --k=64 --l=500 --iterations=1) # Small-k problem sizes + +cutlass_example_add_executable( + 56_hopper_ptr_array_batched_gemm + 56_hopper_ptr_array_batched_gemm.cu + TEST_COMMAND_OPTIONS + TEST_SQUARE + TEST_SQUARE_LARGE_BATCH + TEST_EPILOGUE + TEST_EPILOGUE_LARGE_BATCH + TEST_EPILOGUE_OP + TEST_EPILOGUE_OP_LARGE_BATCH + TEST_SMALLK + TEST_SMALLK_LARGE_BATCH + ) diff --git a/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu b/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu new file mode 100644 index 0000000000..7b20a33548 --- /dev/null +++ b/examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm.cu @@ -0,0 +1,772 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/*! \file
+    \brief Hopper Grouped GEMM example using CUTLASS 3 APIs for NVIDIA Hopper architecture.
+
+    This example demonstrates an implementation of Grouped GEMM using a TMA + GMMA
+    warp-specialized cooperative kernel.
+    For this example all scheduling work is performed on the device.
+    The new feature showcased in this example is on-the-fly modification of TMA descriptors
+    to move between groups/problem_count (represented by groups).
+
+    To run this example:
+
+      $ ./examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm --m=2048 --n=2048 --k=2048 --groups=10
+
+      The above command sizes all 10 groups to the given m, n, and k extents.
+      Omitting any of the problem dimensions randomizes that dimension across the different groups.
+      The same applies to alpha and beta, which are randomized per group when not specified.
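      (Editorial note, not part of the original diff: as a hypothetical illustration, an invocation
      that fixes only M and K, such as

        $ ./examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm --m=2048 --k=512 --groups=16

      would give every group M=2048 and K=512 while N, alpha and beta are drawn per group.)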
+ + To run this example for a set of problems using the benchmark option: + + $ ./examples/57_hopper_grouped_gemm/57_hopper_grouped_gemm --benchmark=./test_benchmark.txt + + Where the test_benchmark.txt may look as such: + 0 256x512x128 + 1 256x512x512 + 2 512x256x128 + 3 256x256x128 + 4 256x512x1024 + 5 1024x512x128 and so on +*/ + +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/group_array_problem_shape.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "helper.h" + +using namespace cute; +using ProblemShape = cutlass::gemm::GroupProblemShape>; // per group +using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand +using ElementB = cutlass::float_e5m2_t; // Element type for B matrix operand +using ElementC = cutlass::half_t; // Element type for C and D matrix operands + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Alignment of C matrix in units of elements (up to 16 bytes) + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using StageCountType = cutlass::gemm::collective::StageCountAuto; // Stage count maximized based on the tile size + +// Different configs for pingpong/cooperative +struct CooperativeConfig { + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedCooperativeFP8FastAccum; + using EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedCooperative; + using TileShape = Shape<_256,_128,_128>; + using ClusterShape = Shape<_2,_2,_1>; +}; + +struct PingpongConfig { + using KernelSchedule = cutlass::gemm::KernelPtrArrayTmaWarpSpecializedPingpongFP8FastAccum; + using 
EpilogueSchedule = cutlass::epilogue::PtrArrayTmaWarpSpecializedPingpong; + using TileShape = Shape<_128,_128,_128>; + using ClusterShape = Shape<_2,_1,_1>; +}; + +template +struct GemmGivenSchedule { + using TileShape = typename ScheduleConfig::TileShape; // Threadblock-level tile size + using ClusterShape = typename ScheduleConfig::ClusterShape; // Shape of the threadblocks in a cluster + using KernelSchedule = typename ScheduleConfig::KernelSchedule; // Kernel to launch + using EpilogueSchedule = typename ScheduleConfig::EpilogueSchedule; // Epilogue to launch + + using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutC *, AlignmentC, + ElementC, LayoutC *, AlignmentC, + EpilogueSchedule, + cutlass::epilogue::fusion::LinearCombination + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA *, AlignmentA, + ElementB, LayoutB *, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + + using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue + >; + + using Gemm = cutlass::gemm::device::GemmUniversalAdapter; +}; + +using GemmKernel = GemmGivenSchedule::GemmKernel; +using Gemm = GemmGivenSchedule::Gemm; + +using GemmKernelPingpong = GemmGivenSchedule::GemmKernel; +using GemmPingpong = GemmGivenSchedule::Gemm; + +// Reference device GEMM implementation type +using DeviceGemmReference = cutlass::reference::device::Gemm< + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + ElementAccumulator>; + +using StrideA = typename Gemm::GemmKernel::InternalStrideA; +using StrideB = typename Gemm::GemmKernel::InternalStrideB; +using StrideC = typename Gemm::GemmKernel::InternalStrideC; +using StrideD = typename Gemm::GemmKernel::InternalStrideD; + +// Host-side allocations +std::vector offset_A; +std::vector offset_B; +std::vector offset_C; +std::vector offset_D; + +std::vector stride_A_host; +std::vector stride_B_host; +std::vector stride_C_host; +std::vector stride_D_host; + +std::vector alpha_host; +std::vector beta_host; + +// Device-side allocations +cutlass::DeviceAllocation problem_sizes; + +cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_ref_D; + +cutlass::DeviceAllocation ptr_A; +cutlass::DeviceAllocation ptr_B; +cutlass::DeviceAllocation ptr_C; +cutlass::DeviceAllocation ptr_D; +cutlass::DeviceAllocation ptr_ref_D; + +cutlass::DeviceAllocation stride_A; +cutlass::DeviceAllocation stride_B; +cutlass::DeviceAllocation stride_C; +cutlass::DeviceAllocation stride_D; + +// Note, this is an array of pointers to alpha and beta scaling values per group +cutlass::DeviceAllocation alpha_device; +cutlass::DeviceAllocation beta_device; +cutlass::DeviceAllocation block_alpha; +cutlass::DeviceAllocation block_beta; + +#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help = false; + + float alpha = FLT_MAX; + float beta = FLT_MAX; + int iterations = 10; + int m = 1024, n = 2048, k = 512, groups = 10; + std::string benchmark_path; + std::vector problem_sizes_host; + int const tma_alignment_bits = 128; + int const alignment = tma_alignment_bits / cutlass::sizeof_bits::value; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("groups", groups); + cmd.get_cmd_line_argument("alpha", alpha, FLT_MAX); + cmd.get_cmd_line_argument("beta", beta, FLT_MAX); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("benchmark", benchmark_path); + + // Decide how to initialize the problems + if (!benchmark_path.empty()) { + if (!benchmark_problems()) { + problem_sizes_host.clear(); + return; + } + } + else { + randomize_problems(cmd); + } + } + + void randomize_problems(cutlass::CommandLine &cmd) { + int cmd_line_m = -1, cmd_line_n = -1, cmd_line_k = -1; + cmd.get_cmd_line_argument("m", cmd_line_m); + cmd.get_cmd_line_argument("n", cmd_line_n); + cmd.get_cmd_line_argument("k", cmd_line_k); + + problem_sizes_host.reserve(groups); + + for (int i = groups; i > 0; i--) { + int m = cmd_line_m; + int n = cmd_line_n; + int k = cmd_line_k; + if (m < 1) { + m = alignment * ((rand() % 64) + 1); + } + if (n < 1) { + n = alignment * ((rand() % 64) + 1); + } + if (k < 1) { + k = alignment * ((rand() % 64) + 1); + } + problem_sizes_host.push_back({m, n, k}); + } + } + + /// Load a benchmark + bool benchmark_problems() { + std::ifstream file(benchmark_path); + if (!file.good()) { + return false; + } + + while (file.good()) { + + int idx = -1; + std::string extent_str; + + file >> idx >> extent_str; + + if (idx < 0 || extent_str.empty()) { + break; + } + + cutlass::gemm::GemmCoord extent; + std::vector tokens; + + cutlass::CommandLine::tokenize(tokens, extent_str, 'x'); + + for (int i = 0; i < int(tokens.size()); ++i) { + int x = std::atoi(tokens.at(i).c_str()); + + // round up + if (x % alignment) { + x += (alignment - (x % alignment)); + } + + extent.at(i) = x; + } + + if (extent.product()) { + problem_sizes_host.push_back({extent.m(), extent.n(), extent.k()}); + } + } + groups = static_cast(problem_sizes_host.size()); + + return true; + } + + /// Prints the usage statement. 
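  // (Editorial note, not part of the original diff.) For the FP8 A operand used by this example,
  // alignment = tma_alignment_bits / sizeof_bits of ElementA = 128 / 8 = 16 elements, so
  // benchmark_problems() above rounds every extent up to a multiple of 16. For example:
  //
  //   x = 250  ->  250 % 16 = 10  ->  x += (16 - 10)  ->  x = 256
  //
  // and randomize_problems() only ever generates multiples of 16 in the first place.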
+ std::ostream & print_usage(std::ostream &out) const { + + out << "57_hopper_grouped_gemm\n\n" + << " Hopper FP8 Grouped GEMM using a Warp Specialized kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM for all groups\n" + << " --n= Sets the N extent of the GEMM for all groups\n" + << " --k= Sets the K extent of the GEMM for all groups\n" + << " --groups= Sets the number of individual GEMM problems for Grouped GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform\n\n" + << " --benchmark= Executes a benchmark problem size.\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "57_hopper_grouped_gemm" << " --m=1024 --n=512 --k=1024 --groups=10 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s, std::vector problem_sizes_host) const + { + // Number of real-valued multiply-adds + uint64_t fmas = uint64_t(); + + for (auto const & problem : problem_sizes_host) { + fmas += static_cast(get<0>(problem)) * + static_cast(get<1>(problem)) * + static_cast(get<2>(problem)); + } + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * uint64_t(fmas); + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +/// Result structure +struct Result +{ + double avg_runtime_ms = 0.0; + double gflops = 0.0; + cutlass::Status status = cutlass::Status::kSuccess; + cudaError_t error = cudaSuccess; + bool passed = false; +}; + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed=2023) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = static_cast(2); + scope_min = static_cast(0); + } else if (bits_input <= 8) { + scope_max = static_cast(2); + scope_min = static_cast(-2); + } else { + scope_max = static_cast(8); + scope_min = static_cast(-8); + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + + return true; +} + +/// Allocates device-side data +void allocate(const Options &options) { + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + + for (int32_t i = 0; i < options.groups; ++i) { + + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + + int64_t elements_A = M * K; + int64_t elements_B = K * N; + int64_t elements_C = M * N; + int64_t elements_D = M * N; + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + + stride_A_host.push_back(cutlass::make_cute_packed_stride(StrideA{}, {M, K, 1})); + stride_B_host.push_back(cutlass::make_cute_packed_stride(StrideB{}, {N, K, 1})); + 
stride_C_host.push_back(cutlass::make_cute_packed_stride(StrideC{}, {M, N, 1})); + stride_D_host.push_back(cutlass::make_cute_packed_stride(StrideD{}, {M, N, 1})); + + } + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + block_ref_D.reset(total_elements_D); + block_alpha.reset(options.groups); + block_beta.reset(options.groups); +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + + uint64_t seed = 2020; + + problem_sizes.reset(options.groups); + problem_sizes.copy_from_host(options.problem_sizes_host.data()); + + // + // Assign pointers + // + + std::vector ptr_A_host(options.groups); + std::vector ptr_B_host(options.groups); + std::vector ptr_C_host(options.groups); + std::vector ptr_D_host(options.groups); + std::vector ptr_alpha_host(options.groups); + std::vector ptr_beta_host(options.groups); + + for (int32_t i = 0; i < options.groups; ++i) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + alpha_host.push_back((options.alpha == FLT_MAX) ? static_cast((rand() % 5) + 1) : options.alpha); + beta_host.push_back((options.beta == FLT_MAX) ? static_cast(rand() % 5) : options.beta); + ptr_alpha_host.at(i) = block_alpha.get() + i; + ptr_beta_host.at(i) = block_beta.get() + i; + } + + ptr_A.reset(options.groups); + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(options.groups); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_C.reset(options.groups); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(options.groups); + ptr_D.copy_from_host(ptr_D_host.data()); + + stride_A.reset(options.groups); + stride_A.copy_from_host(stride_A_host.data()); + + stride_B.reset(options.groups); + stride_B.copy_from_host(stride_B_host.data()); + + stride_C.reset(options.groups); + stride_C.copy_from_host(stride_C_host.data()); + + stride_D.reset(options.groups); + stride_D.copy_from_host(stride_D_host.data()); + + alpha_device.reset(options.groups); + alpha_device.copy_from_host(ptr_alpha_host.data()); + beta_device.reset(options.groups); + beta_device.copy_from_host(ptr_beta_host.data()); + + initialize_block(block_A, seed + 2023); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2021); + block_alpha.copy_from_host(alpha_host.data()); + block_beta.copy_from_host(beta_host.data()); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +template +typename GemmT::Arguments args_from_options(const Options &options, bool host_problem_shapes_available = true) +{ + cutlass::KernelHardwareInfo hw_info; + // Change device_id to another value if you are running on a machine with multiple GPUs and wish + // to use a GPU other than that with device ID 0. + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + + typename GemmT::Arguments arguments; + decltype(arguments.epilogue.thread) fusion_args; + + if (options.alpha != FLT_MAX && options.beta != FLT_MAX) { + // If both alpha/beta are provided (via cmd line args) and are scalar, i.e., same alpha/beta applies to all batches. 
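  // (Editorial note, not part of the original diff.) In the two branches below, the last entry of
  // dAlpha/dBeta is the per-group stride of the scaling factors: a stride of 0 broadcasts the single
  // scalar to every group, while a stride of 1 steps through the alpha_ptr_array / beta_ptr_array
  // entries one element per group.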
+ fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = nullptr; + fusion_args.beta_ptr_array = nullptr; + // Single alpha and beta for all groups + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 0}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 0}; + } + else { + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups. + fusion_args.alpha = 0; + fusion_args.beta = 0; + fusion_args.alpha_ptr = nullptr; + fusion_args.beta_ptr = nullptr; + fusion_args.alpha_ptr_array = alpha_device.get(); + fusion_args.beta_ptr_array = beta_device.get(); + // One alpha and beta per each group + fusion_args.dAlpha = {cute::_0{}, cute::_0{}, 1}; + fusion_args.dBeta = {cute::_0{}, cute::_0{}, 1}; + } + + if (host_problem_shapes_available) { + arguments = typename GemmT::Arguments { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), options.problem_sizes_host.data()}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info + }; + } + else { + arguments = typename GemmT::Arguments { + cutlass::gemm::GemmUniversalMode::kGrouped, + {options.groups, problem_sizes.get(), nullptr}, + {ptr_A.get(), stride_A.get(), ptr_B.get(), stride_B.get()}, + {fusion_args, ptr_C.get(), stride_C.get(), ptr_D.get(), stride_D.get()}, + hw_info + }; + } + + return arguments; +} + +bool verify(const Options &options) { + bool passed = true; + for (int32_t i = 0; i < options.groups; ++i) { + auto problem = options.problem_sizes_host.at(i); + auto M = get<0>(problem); + auto N = get<1>(problem); + auto K = get<2>(problem); + cutlass::TensorRef ref_A(block_A.get() + offset_A.at(i), Gemm::LayoutA::packed({M, K})); + cutlass::TensorRef ref_B(block_B.get() + offset_B.at(i), Gemm::LayoutB::packed({K, N})); + cutlass::TensorRef ref_C(block_C.get() + offset_C.at(i), Gemm::LayoutC::packed({M, N})); + cutlass::TensorRef ref_D(block_ref_D.get() + offset_D.at(i), Gemm::LayoutD::packed({M, N})); + + // + // Compute reference output + // + + // Create instantiation for device reference gemm kernel + DeviceGemmReference gemm_reference; + + // Launch device reference gemm kernel + gemm_reference( + {M, N, K}, + ElementAccumulator(alpha_host.at(i)), + ref_A, + ref_B, + ElementAccumulator(beta_host.at(i)), + ref_C, + ref_D); + + // Wait for kernel to finish + CUDA_CHECK(cudaDeviceSynchronize()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + passed &= cutlass::reference::device::BlockCompareEqual(block_ref_D.get() + offset_D.at(i), block_D.get() + offset_D.at(i), M * N); + #if 0 + std::cout << "Group: " << i << " Status: " << passed << std::endl; + #endif + } + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options, bool host_problem_shapes_available = true) +{ + allocate(options); + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + GemmT gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options, host_problem_shapes_available); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = GemmT::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); 
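  // (Editorial note, not part of the original diff.) The queried workspace size may be non-zero for
  // grouped GEMM; CUTLASS uses it for kernel-internal bookkeeping, so it is allocated up front here
  // and handed to initialize() below rather than skipped.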
+ + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average setup and runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0, options.problem_sizes_host); + + std::cout << " Problem Sizes, Alpha, Beta " << std::endl; + for (int32_t i = 0; i < options.groups; ++i) { + std::cout << " " << options.problem_sizes_host.at(i); + std::cout << ", " << alpha_host.at(i) << ", " << beta_host.at(i) << std::endl; + } + std::cout << " Groups : " << options.groups << std::endl; + std::cout << " Avg runtime : " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS : " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.3 Toolkit to run this example + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 3)) { + std::cerr << "This example requires CUDA 12.3 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. 
+    return 0;
+  }
+
+  cudaDeviceProp props;
+  int current_device_id;
+  CUDA_CHECK(cudaGetDevice(&current_device_id));
+  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
+  cudaError_t error = cudaGetDeviceProperties(&props, 0);
+  if (props.major < 9) {
+    std::cerr
+      << "This example requires a GPU of NVIDIA's Hopper Architecture or "
+      << "later (compute capability 90 or greater).\n";
+    return 0;
+  }
+  //
+  // Parse options
+  //
+
+  Options options;
+
+  options.parse(argc, args);
+
+  if (options.help) {
+    options.print_usage(std::cout) << std::endl;
+    return 0;
+  }
+
+  //
+  // Evaluate CUTLASS kernels
+  //
+
+#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED)
+  std::cout << "\n*** Cooperative schedule ***" << std::endl;
+  run<Gemm>(options);
+  std::cout << "\n*** Cooperative schedule (host problem shapes unavailable) ***" << std::endl;
+  run<Gemm>(options, false /*host_problem_shapes_available*/);
+  std::cout << "\n*** Pingpong schedule ***" << std::endl;
+  run<GemmPingpong>(options);
+  std::cout << "\n*** Pingpong schedule (host problem shapes unavailable) ***" << std::endl;
+  run<GemmPingpong>(options, false /*host_problem_shapes_available*/);
+#endif
+
+  return 0;
+}
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/examples/57_hopper_grouped_gemm/CMakeLists.txt b/examples/57_hopper_grouped_gemm/CMakeLists.txt
new file mode 100644
index 0000000000..1dadbfa813
--- /dev/null
+++ b/examples/57_hopper_grouped_gemm/CMakeLists.txt
@@ -0,0 +1,66 @@
+# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+# Note that we set --iterations=0 for all tests below to disable the performance benchmarking.
+# Only the correctness check will be run by these commands.
+ +set(TEST_RANDOM --iterations=0) # Random problem sizes +set(TEST_RANDOM_LARGE_GROUP --groups=500 --iterations=0) # Random problem sizes + +set(TEST_EPILOGUE --alpha=0.5 --beta=0.5 --iterations=0) # Random problem sizes +set(TEST_EPILOGUE_LARGE_GROUP --alpha=1.5 --beta=2.0 --groups=500 --iterations=0) # Random problem sizes + +set(TEST_EPILOGUE_OP --beta=0.5 --iterations=1) # Random problem sizes +set(TEST_EPILOGUE_OP_LARGE_GROUP --alpha=1.5 --iterations=1) # Random problem sizes + +set(TEST_FIXED --m=2048 --n=5120 --k=8192 --groups=50 --iterations=0) # Fixed problem sizes +set(TEST_FIXED_LARGE_GROUP --m=2048 --n=512 --k=512 --groups=512 --iterations=0) # Fixed problem sizes + +set(TEST_SMALL --m=256 --n=128 --iterations=0) # Small problem sizes +set(TEST_SMALL_LARGE_GROUP --m=128 --n=128 --groups=500 --iterations=0) # Small problem sizes + +set(TEST_RANDOM_PERF --iterations=10) # Random problem sizes +set(TEST_RANDOM_PERF_LARGE_GROUP --groups=500 --iterations=10) # Random problem sizes + +cutlass_example_add_executable( + 57_hopper_grouped_gemm + 57_hopper_grouped_gemm.cu + TEST_COMMAND_OPTIONS + TEST_RANDOM + TEST_RANDOM_LARGE_GROUP + TEST_EPILOGUE + TEST_EPILOGUE_LARGE_GROUP + TEST_EPILOGUE_OP + TEST_EPILOGUE_OP_LARGE_GROUP + TEST_FIXED + TEST_FIXED_LARGE_GROUP + TEST_SMALL + TEST_SMALL_LARGE_GROUP + TEST_RANDOM_PERF + TEST_RANDOM_PERF_LARGE_GROUP + ) diff --git a/examples/58_ada_fp8_gemm/CMakeLists.txt b/examples/58_ada_fp8_gemm/CMakeLists.txt new file mode 100644 index 0000000000..2af325424f --- /dev/null +++ b/examples/58_ada_fp8_gemm/CMakeLists.txt @@ -0,0 +1,34 @@ + +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
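# (Editorial note, not part of the original diff.) Unlike examples 56 and 57 above, this target is
# added without TEST_COMMAND_OPTIONS; the TEST_* option sets passed for those examples each appear
# to define a separate test invocation of the executable, whereas this one is assumed to be run
# manually or with its default invocation only.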
+ + +cutlass_example_add_executable( + 58_ada_fp8_gemm + ada_fp8_gemm.cu + ) diff --git a/examples/58_ada_fp8_gemm/ada_fp8_gemm.cu b/examples/58_ada_fp8_gemm/ada_fp8_gemm.cu new file mode 100644 index 0000000000..79bead365b --- /dev/null +++ b/examples/58_ada_fp8_gemm/ada_fp8_gemm.cu @@ -0,0 +1,826 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Example of running an Ada FP8 GEMM. + + In addition to using FP8 Tensor Core instructions, the Ada FP8 GEMM uses a distinct epilogue + that enables additional scaling of operands/outputs, storing a pre-activation-function output + tensor (called the "auxiliary" output), and computing the absolute maximum value of the + outputs. 
+ + Pseudocode for this epilogue is as follows: + + Aux = ((alpha * scale_a * scale_b) * accumulator) + ((beta * scale_c) * source) + bias + D = activation(Aux) + + if Aux is fp8 type: + abs_max_output = max( abs(aux) | (for every aux in Aux)) + Aux = scale_aux * Aux + endif + + if D is fp8 type: + abs_max_output = max( abs(d) | (for every d in D)) + D = scale_d * D + endif + + Parameter Aux is optionally stored to global memory +*/ + +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/util/command_line.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm_complex.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gemm.h" + +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/epilogue/thread/linear_combination_generic_with_scaling.h" +#include "cutlass/gemm/device/gemm_universal_with_absmax.h" + +#include "cutlass/layout/matrix.h" +#include "cutlass/matrix_coord.h" +#include "cutlass/gemm/device/gemm_universal_adapter.h" + + +using ElementA = cutlass::float_e4m3_t; +using ElementB = cutlass::float_e4m3_t; +using ElementOutput = cutlass::float_e4m3_t; +using ElementAuxOutput = ElementOutput; +using ElementAccumulator = float; +using LayoutA = cutlass::layout::RowMajor; +using LayoutB = cutlass::layout::ColumnMajor; +using LayoutC = cutlass::layout::RowMajor; +static int const kStages = 3; +static int const kAlignmentA = 16; +static int const kAlignmentB = 16; + +using EpilogueOutputOp = cutlass::epilogue::thread::LinearCombinationGenericWithScalingAndAbsMax< + cutlass::epilogue::thread::ReLu, + ElementOutput, + ElementAuxOutput, + 8, + ElementAccumulator, + ElementAccumulator + >; + +template +using Gemm_ = cutlass::gemm::device::GemmUniversalWithAbsMax< + ElementA, LayoutA, ElementB, LayoutB, ElementOutput, LayoutC, + ElementAccumulator, cutlass::arch::OpClassTensorOp, cutlass::arch::Sm89, + cutlass::gemm::GemmShape<128, 64, 128>, cutlass::gemm::GemmShape<64, 32, 128>, cutlass::gemm::GemmShape<16, 8, 32>, + EpilogueOutputOp, cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>, kStages, + kAlignmentA, kAlignmentB, MathOperator + >; + +using ElementAbsmax = typename EpilogueOutputOp::ElementAbsmax; + + +// Command line options parsing +struct Options { + + bool help; + bool error; + bool reference_check; + cutlass::gemm::GemmCoord problem_size; + + int iterations; + int warmup_iterations; + + bool scale_A; + bool scale_B; + bool scale_C; + + float alpha; + float beta; + + Options(): + help(false), + error(false), + reference_check(false), + iterations(20), + warmup_iterations(5), + scale_A(true), + scale_B(true), + scale_C(true), + alpha(1.f), + beta(0.f) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("iterations", iterations, 20); + cmd.get_cmd_line_argument("warmup_iterations", warmup_iterations, 5); + cmd.get_cmd_line_argument("reference-check", reference_check, false); + cmd.get_cmd_line_argument("scale-A", scale_A, true); + cmd.get_cmd_line_argument("scale-B", scale_B, true); + 
cmd.get_cmd_line_argument("scale-C", scale_C, true); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + + int m, n, k; + cmd.get_cmd_line_argument("m", m, 1024); + cmd.get_cmd_line_argument("n", n, 1024); + cmd.get_cmd_line_argument("k", k, 1024); + + problem_size = cutlass::gemm::GemmCoord{m, n, k}; + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "58_ada_fp8_gemm\n\n" + << " This example executes a GEMM using Ada FP8 Tensor Core operations. In addition to performing\n" + << " a normal GEMM, the kernel performs the following operations:\n" + << " Aux = ((alpha * scale_a * scale_b) * accumulator) + ((beta * scale_c) * source) + bias\n" + << " D = activation(Aux)\n\n" + << " if Aux is fp8:\n" + << " abs_max_output = max( abs(aux) | (for every aux in Aux) )\n" + << " Aux = scale_aux * Aux\n\n" + << " if D is fp8 type:\n" + << " abs_max_output = max( abs(d) | (for every d in D) )\n" + << " D = scale_d * D\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M dimension of the GEMM\n" + << " --n= Sets the N dimension of the GEMM\n" + << " --k= Sets the K dimension of the GEMM\n" + << " --scale-A= Whether to apply a scaling factor to operand A (default: true)\n" + << " --scale-B= Whether to apply a scaling factor to operand B (default: true)\n" + << " --scale-C= Whether to apply a scaling factor to operand C (default: true)\n" + << " --iterations= Number of profiling iterations to perform\n" + << " --warmup-iterations= Number of warmup iterations to perform\n" + << " --reference-check= If true, performs reference check\n"; + + return out; + } + + /// Compute performance in GFLOP/s + float gflops(float runtime_s) const { + // Two flops per multiply-add + return 2.0f * float(problem_size.product()) / float(1.0e9) / runtime_s; + } +}; + +/// Helper class to run the kernel +template +struct TestbedRunner { + + using ElementAccumulator = typename Gemm::ElementAccumulator; + using ElementCompute = typename Gemm::GemmKernel::Epilogue::OutputOp::ElementCompute; + using ElementScalingFactor = typename Gemm::EpilogueOutputOp::ElementScalingFactor; + + static bool const kScaleAux = Gemm::EpilogueOutputOp::kIsScalingAndAmaxAuxOutputNeeded; + static bool const kScaleOutput = Gemm::EpilogueOutputOp::kIsScalingAndAmaxOutputNeeded; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint64_t seed; + + cutlass::HostTensor tensor_A; + cutlass::HostTensor tensor_B; + cutlass::HostTensor tensor_C; + cutlass::HostTensor tensor_Aux; + cutlass::HostTensor tensor_D; + cutlass::HostTensor tensor_Vector; + cutlass::HostTensor tmp_D; + cutlass::HostTensor reference_D; + cutlass::HostTensor reference_Aux; + cutlass::HostTensor scale_A; + cutlass::HostTensor scale_B; + cutlass::HostTensor scale_C; + cutlass::HostTensor scale_D; + cutlass::HostTensor scale_Aux; + cutlass::HostTensor abs_max_Aux; + cutlass::HostTensor abs_max_D; + cutlass::HostTensor reference_abs_max_Aux; + cutlass::HostTensor reference_abs_max_D; + + // + // Methods + // + + TestbedRunner( + bool scaleA = true, + bool scaleB = true, + bool scaleC = true, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint64_t seed_ = 2080 + ): + 
init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + /// Helper to initialize scaling factors + template + bool initialize_scale_factor(cutlass::TensorView view, uint64_t seed, int bits=0) { + cutlass::reference::host::TensorFillRandomUniform(view, seed, double(1.), double(0.), bits); + return true; + } + + /// Helper to initialize a tensor view + template + bool initialize_tensor( + cutlass::TensorView view, + cutlass::Distribution::Kind dist_kind, + uint64_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } else { + scope_max = 8; + scope_min = -8; + } + + cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Identity) { + + cutlass::reference::host::TensorFillIdentity(view); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::host::TensorFillRandomGaussian(view, seed, 0, 0.5); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + cutlass::reference::host::BlockFillSequential( + view.data(), view.capacity()); + } + else { + std::cerr << "Not implemented"; + return false; + } + + return true; + } + + /// Initializes data structures + void initialize(const Options& options) { + // + // Allocate the GEMM workspace + // + + tensor_A.resize(options.problem_size.mk()); + tensor_B.resize(options.problem_size.kn()); + tensor_C.resize(options.problem_size.mn()); + tensor_D.resize(options.problem_size.mn()); + tensor_Vector.resize({1, options.problem_size.n()}); + reference_D.resize(options.problem_size.mn(), false); + tmp_D.resize(options.problem_size.mn(), false); + + initialize_tensor(tensor_A.host_view(), init_A, seed + 2019); + initialize_tensor(tensor_B.host_view(), init_B, seed + 2018); + initialize_tensor(tensor_C.host_view(), init_C, seed + 2017); + initialize_tensor(tensor_Vector.host_view(), init_C, seed + 2020); + + // It is possible to randomly initialize to all zeros, so override this with non-zeros + // in the upper left corner of each operand. 
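    // (Editorial note, not part of the original diff.) Without this override, a random fill could
    // in principle leave an operand entirely zero, in which case the GEMM output would be zero
    // regardless of the kernel's correctness and the reference comparison would pass vacuously;
    // writing a known non-zero value into each operand guards against that.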
+ cutlass::Coord<2> origin(0); + tensor_A.host_view().at(origin) = typename Gemm::ElementA(1); + tensor_B.host_view().at(origin) = typename Gemm::ElementB(1); + tensor_C.host_view().at(origin) = typename Gemm::ElementC(1); + tensor_Vector.host_view().at(origin) = typename Gemm::ElementC(1); + + cutlass::reference::host::TensorFill(tensor_D.host_view()); + cutlass::reference::host::TensorCopy(reference_D.host_view(), tensor_C.host_view()); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); + tensor_Vector.sync_device(); + + int scale_bits = 2; + if (options.scale_A) { + scale_A.resize({1, 1}); + initialize_scale_factor(scale_A.host_view(), seed + 2021, scale_bits); + scale_A.sync_device(); + } + + if (options.scale_B) { + scale_B.resize({1, 1}); + initialize_scale_factor(scale_B.host_view(), seed + 2022, scale_bits); + scale_B.sync_device(); + } + + if (options.scale_C) { + scale_C.resize({1, 1}); + initialize_scale_factor(scale_C.host_view(), seed + 2023, scale_bits); + scale_C.sync_device(); + } + + if (kScaleOutput) { + scale_D.resize({1, 1}); + initialize_scale_factor(scale_D.host_view(), seed + 2024, scale_bits); + scale_D.sync_device(); + + abs_max_D.resize({1, 1}); + cutlass::reference::host::TensorFill(abs_max_D.host_view()); + abs_max_D.sync_device(); + + reference_abs_max_D.resize({1, 1}); + } + + if (kScaleAux) { + tensor_Aux.resize(options.problem_size.mn()); + cutlass::reference::host::TensorFill(tensor_Aux.host_view()); + tensor_Aux.sync_device(); + + scale_Aux.resize({1, 1}); + initialize_scale_factor(scale_Aux.host_view(), seed + 2025, scale_bits); + scale_Aux.sync_device(); + + abs_max_Aux.resize({1, 1}); + cutlass::reference::host::TensorFill(abs_max_Aux.host_view()); + abs_max_Aux.sync_device(); + + reference_Aux.resize(options.problem_size.mn(), false); + reference_abs_max_Aux.resize({1, 1}); + } + } + + /// Compares computed reference with device reference and outputs to a file if incorrect + bool compare_reference(const Options& options) { + + tensor_D.sync_host(); + + bool passed = cutlass::reference::host::TensorEquals(reference_D.host_view(), tensor_D.host_view()); + + if (kScaleAux) { + tensor_Aux.sync_host(); + abs_max_Aux.sync_host(); + passed &= cutlass::reference::host::TensorEquals(reference_Aux.host_view(), tensor_Aux.host_view()); + passed &= cutlass::reference::host::TensorEquals(abs_max_Aux.host_view(), reference_abs_max_Aux.host_view()); + } + + if (kScaleOutput) { + abs_max_D.sync_host(); + passed &= cutlass::reference::host::TensorEquals(abs_max_D.host_view(), reference_abs_max_D.host_view()); + } + + if (!passed) { + std::cerr << "Reference check failed" << std::endl; + + std::string output_file = "testbed_with_amax_errors.txt"; + std::ofstream file(output_file); + + file + << "problem: " << options.problem_size + << ", alpha: " << options.alpha << ", beta: " << options.beta << "\n\n"; + + file + << "A =\n" << tensor_A.host_view() + << "\nB =\n" << tensor_B.host_view() + << "\nC =\n" << tensor_C.host_view() + << "\nVector =\n" << tensor_Vector.host_view() + << "\nScaleA = " << scale_A.host_view() + << "\nScaleB = " << scale_B.host_view() + << "\nScaleC = " << scale_C.host_view() + << "\nScaleD = " << scale_D.host_view() + << "\nScaleAux = " << scale_Aux.host_view() + << "\n\nReference D =\n" << reference_D.host_view() + << "\nComputed D =\n" << tensor_D.host_view(); + if (kScaleAux) { + file + << "\n\nReference Aux =\n" << reference_Aux.host_view() + << "\nComputed Aux =\n" << 
tensor_Aux.host_view() + << "\n\nReference Absmax Aux = " << reference_abs_max_Aux.host_view() + << "\nComputed Absmax Aux = " << abs_max_Aux.host_view(); + } + if (kScaleOutput) { + file + << "\n\nReference Absmax D = " << reference_abs_max_D.host_view() + << "\nComputed Absmax D = " << abs_max_D.host_view(); + } + + std::cerr << "Dumped results to " << output_file << std::endl; + + } + + return passed; + } + + /// Verifies the result is a GEMM + bool verify(const Options& options) { + + cutlass::Coord<2> origin(0); + ElementCompute scaled_alpha = options.alpha; + if (options.scale_A) { + scaled_alpha *= scale_A.host_view().at(origin); + } + if (options.scale_B) { + scaled_alpha *= scale_B.host_view().at(origin); + } + + ElementCompute scaled_beta = options.beta; + if (options.scale_C) { + scaled_beta *= scale_C.host_view().at(origin); + } + + // + // Verify + // + + cutlass::reference::host::GemmComplex< + typename Gemm::ElementA, typename Gemm::LayoutA, + typename Gemm::ElementB, typename Gemm::LayoutB, + typename Gemm::ElementC, typename Gemm::LayoutC, + ElementCompute, ElementAccumulator, ElementAccumulator + >( + options.problem_size, + scaled_alpha, + tensor_A.host_ref(), + Gemm::kTransformA, + tensor_B.host_ref(), + Gemm::kTransformB, + scaled_beta, + tensor_C.host_ref(), + tmp_D.host_ref(), + ElementAccumulator(0) + ); + + ElementCompute tmp_abs_max_Aux(0.); + ElementCompute tmp_abs_max_D(0.); + + cutlass::NumericConverter cvt_c_to_compute; + cutlass::NumericConverter cvt_accum_to_compute; + cutlass::NumericConverter cvt_compute_to_accum; + cutlass::NumericConverter cvt_compute_to_d; + cutlass::NumericConverter cvt_compute_to_aux; + + cutlass::absolute_value_op abs; + cutlass::maximum_with_nan_propogation max; + cutlass::epilogue::thread::ReLu act; + + ElementScalingFactor d_scale = kScaleOutput ? scale_D.host_view().at(origin) : ElementScalingFactor(1.); + + for (int m = 0; m < options.problem_size.m(); ++m) { + for (int n = 0; n < options.problem_size.n(); ++n) { + ElementCompute intermediate = cvt_accum_to_compute(tmp_D.host_view().at({m, n})); + ElementCompute bias = cvt_c_to_compute(tensor_Vector.host_view().at({0, n})); + ElementCompute aux = intermediate + bias; + ElementCompute d = act(aux); + tmp_abs_max_Aux = max(abs(aux), tmp_abs_max_Aux); + tmp_abs_max_D = max(abs(d), tmp_abs_max_D); + reference_D.host_view().at({m, n}) = cvt_compute_to_d(d * d_scale); + + if (kScaleAux) { + reference_Aux.host_view().at({m, n}) = cvt_compute_to_aux(aux * scale_Aux.host_view().at(origin)); + } + } + } + + if (kScaleAux) { + reference_abs_max_Aux.host_view().at(origin) = cvt_compute_to_accum(tmp_abs_max_Aux); + } + + if (kScaleOutput) { + reference_abs_max_D.host_view().at(origin) = cvt_compute_to_accum(tmp_abs_max_D); + } + + return compare_reference(options); + } + + /// Returns true if the CUDA device is sufficient to execute the kernel. + bool sufficient() const { + + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4)) { + std::cerr << "This example requires CUDA 12.4 or greater." 
<< std::endl; + return false; + } + + size_t smem_size = sizeof(typename Gemm::GemmKernel::SharedStorage); + + cudaDeviceProp properties; + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + + if (result != cudaSuccess) { + std::cerr << "cudaGetDevice() failed with error: " << cudaGetErrorString(result) << std::endl; + return false; + } + + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() failed with error: " << cudaGetErrorString(result) << std::endl; + return false; + } + + if (properties.major < 8 || (properties.major == 8 && properties.minor < 9)) { + std::cerr << "CUTLASS's Ada FP8 GEMM example requires a device of compute capability 89 or higher.\n" << std::endl; + return false; + } + + if (properties.sharedMemPerBlockOptin < smem_size) { + std::cerr << "Insufficient shared memory. Need " << smem_size + << ", but device only has " << properties.sharedMemPerBlockOptin << std::endl; + return false; + } + + return true; + } + + /// Executes one test + bool run(Options& options) + { + + // Waive test if insufficient CUDA device + if (!sufficient()) { + std::cerr << "Insufficient resources to run the kernel." << std::endl; + return false; + } + + this->initialize(options); + + // + // Initialize the GEMM operator + // + + typename Gemm::EpilogueOutputOp::Params::ActivationParams activation_params{ + ElementCompute(options.alpha), + ElementCompute(options.beta) + }; + typename Gemm::EpilogueOutputOp::Params epilogue_params{ + activation_params, + scale_A.device_data(), + scale_B.device_data(), + scale_C.device_data(), + scale_D.device_data(), + scale_Aux.device_data(), + abs_max_Aux.device_data(), + abs_max_D.device_data() + }; + + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + options.problem_size, + /* batch_count = */ 1, + epilogue_params, + tensor_A.device_data(), + tensor_B.device_data(), + tensor_C.device_data(), + tensor_D.device_data(), + tensor_Aux.device_data(), + tensor_Vector.device_data(), + options.problem_size.m() * options.problem_size.k(), + options.problem_size.n() * options.problem_size.k(), + options.problem_size.m() * options.problem_size.n(), + options.problem_size.m() * options.problem_size.n(), + (int)options.problem_size.m(), // Batch stride vector + tensor_A.layout().stride(0), + tensor_B.layout().stride(0), + tensor_C.layout().stride(0), + tensor_D.layout().stride(0), + (int64_t)0 // Leading dimension of vector. 
This must be 0 + }; + + Gemm gemm_op; + + cutlass::Status status = gemm_op.can_implement(arguments); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Gemm::can_implement() failed" << std::endl; + return false; + } + + size_t workspace_size = Gemm::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + status = gemm_op.initialize(arguments, workspace.get()); + if (status != cutlass::Status::kSuccess) { + std::cerr << "Gemm::initialize() failed" << std::endl; + return false; + } + + // + // Run the GEMM + // + + status = gemm_op(); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "Gemm::run() failed" << std::endl; + return false; + } + + cudaError_t cuda_error = cudaDeviceSynchronize(); + if (cuda_error != cudaSuccess) { + std::cerr << "CUDA error: " << cudaGetErrorString(cuda_error) << std::endl; + return false; + } + + // + // Verify + // + + bool passed = true; + if (options.reference_check) { + passed &= this->verify(options); + } else { + std::cout << "Skipped reference check" << std::endl; + } + + // + // Warm up + // + + for (int i = 0; i < options.warmup_iterations; ++i) { + gemm_op(); + } + + // + // Profile + // + + cudaEvent_t events[2]; + cudaError_t error; + for (auto & event : events) { + error = cudaEventCreate(&event); + if (error != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(error) << std::endl; + return false; + } + } + + // Record an event at the start of a series of GEMM operations + error = cudaEventRecord(events[0]); + if (error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(error) << std::endl; + return false; + } + + // Run profiling loop + for (int iter = 0; iter < options.iterations; ++iter) { + gemm_op(); + } + + // Record an event when the GEMM operations have been launched. + error = cudaEventRecord(events[1]); + if (error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(error) << std::endl; + return false; + } + + // Wait for work on the device to complete. + error = cudaEventSynchronize(events[1]); + if (error != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(error) << std::endl; + return false; + } + + // Measure elapsed runtime + float runtime_ms = 0; + error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (error != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(error) << std::endl; + return false; + } + + // Compute average runtime and GFLOPs. 
+ runtime_ms = runtime_ms / float(options.iterations); + float gflops = options.gflops(runtime_ms / 1000.0f); + + std::cout << "Problem size: " << options.problem_size.m() << 'x' << options.problem_size.n() << 'x' << options.problem_size.k() << std::endl; + std::cout << "Runtime (ms): " << runtime_ms << std::endl; + std::cout << "GFLOPs/sec: " << gflops << std::endl; + + // Cleanup + for (auto event : events) { + (void)cudaEventDestroy(event); + } + + return passed; + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const** argv) { + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4) || + (props.major != 8 && props.minor != 9)) { + + // + // This example requires an NVIDIA Ada-architecture GPU. + // + + std::cout + << "CUTLASS's FP8 SM89 example requires a GPU of NVIDIA's Ada architecture " + << "and CUDA toolkit version 12.4 or later.\n"; + + return 0; + } + + // + // Parse options + // + + Options options; + + options.parse(argc, argv); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." << std::endl; + return -1; + } + + std::cout << "Running GEMM with staged accumulation (OpMultiplyAdd)" << std::endl; + std::cout << "=====================================================" << std::endl; + TestbedRunner> testbed_staged_accum; + bool passed = testbed_staged_accum.run(options); + + if (passed) { + std::cout << "Passed" << std::endl; + } else { + std::cout << "Failed" << std::endl; + } + + std::cout << "\nRunning GEMM with fast accumulation (OpMultiplyAddFastAccum)" << std::endl; + std::cout << "============================================================" << std::endl; + TestbedRunner> testbed_fast_accum; + passed = testbed_fast_accum.run(options); + + if (passed) { + std::cout << "Passed" << std::endl; + } else { + std::cout << "Failed" << std::endl; + } + + return 0; +} diff --git a/examples/59_ampere_gather_scatter_conv/CMakeLists.txt b/examples/59_ampere_gather_scatter_conv/CMakeLists.txt new file mode 100644 index 0000000000..ce22cd1f37 --- /dev/null +++ b/examples/59_ampere_gather_scatter_conv/CMakeLists.txt @@ -0,0 +1,40 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. 
+# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +if (NOT MSVC) + +cutlass_example_add_executable( + 59_ampere_gather_scatter_conv + ampere_gather_scatter_conv.cu +) + +if (CUTLASS_ENABLE_OPENMP_TESTS AND OpenMP_CXX_FOUND) + target_link_libraries(59_ampere_gather_scatter_conv PRIVATE OpenMP::OpenMP_CXX) +endif() + +endif() diff --git a/examples/59_ampere_gather_scatter_conv/README.md b/examples/59_ampere_gather_scatter_conv/README.md new file mode 100644 index 0000000000..4aac053639 --- /dev/null +++ b/examples/59_ampere_gather_scatter_conv/README.md @@ -0,0 +1,209 @@ +# Example 59: Ampere gather/scatter convolution + +CuTe and CUTLASS 3.x based Ampere convolution forward propagation kernel capable of operating on both affine and gather/scatter tensors. + +Example executions: +```sh +./59_ampere_gather_scatter_conv +./59_ampere_gather_scatter_conv --n=108 +./59_ampere_gather_scatter_conv --n=4096 --i=1 +./59_ampere_gather_scatter_conv --n=1080 --i=1000 +./59_ampere_gather_scatter_conv --n=131072 --i=1000 --no-check +``` + +This example demonstrates a few super cool features of CUTLASS and CuTe. It shows off +1. A dense conv 3D fprop kernel written as a single file ... +2. ... that leverages off-the-shelf CUTLASS collectives to show how custom kernels can use collectives ... +3. ... and uses the exact same templated kernel to also stamp out a gather/scatter 3D fprop conv ... +4. ... while getting near peak performance of the Ampere class tensor core on Ampere and Ada GPUs ... +5. ... by using static cute shapes and strides in case problem shapes are known at compile time. + +## A dense conv 3D fprop kernel written in CUTLASS 3.x and CuTe + +The most common strategy for implementing high performance convolution kernels on the GPU is to transform +the activation tensor in such a way that we can perform the computation as a GEMM. This is called the +image to column (im2col) transformation. [CUTLASS 2.x implementation of im2col based convolutions is +documented separately](../../media/docs/implicit_gemm_convolution.md), and here we consider a fresh approach for CuTe. + +A 3D convolution has the following input tensors: +- Activation tensor (Act): `((N,(D,H,W)), (C,(1,1,1)))` +- Filter tensor (Flt): `( K, (C,(T,R,S)))` +- Output tensor (Out): `((N,(Z,P,Q)), K )` + +Where +- N := number of images +- DHW := spatial dimensions of the activation tensor +- C := channel dimension of the activation tensor +- K := channel dimension of the filter and output tensor +- TRS := spoke dimensions of the filter tensor +- ZPQ := spatial dimensions of the output tensor + +As is evident in the tensor shapes, these cannot be issued to a GEMM just yet, since there is no +logical M, N, and K modes we can group the tensor modes into. 
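+
+For concreteness, these nested shapes map directly onto CuTe layouts. The sketch below mirrors the activation-tensor construction used later in `ampere_gather_scatter_conv.cu`; the extents `N, D, H, W, C` are plain integers here purely for illustration (the example itself fixes every extent except `N` at compile time):
+
+```cpp
+// Activation tensor ((N,(D,H,W)), (C,(1,1,1))) as a CuTe layout:
+// NDHWC in memory, with C as the contiguous (stride-1) mode.
+auto act_layout = make_layout(
+  make_shape (make_shape (      N,     D,   H, W), make_shape (   C, _1{}, _1{}, _1{})),
+  make_stride(make_stride(D*H*W*C, H*W*C, W*C, C), make_stride(_1{}, _0{}, _0{}, _0{})));
+```
+
+The filter and output tensors are set up analogously in the example source.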
+ +Notice that every spoke of the filter tensor (TRS) will be applied to some (offset) view of the +activation tensor, thus expanding the logical size of the activation tensor. +Additionally, a similar logical transform of the spatial dimensions can be encoded as a function of the +padding, dilations, traversal strides, and filter spokes. This gets us to our im2col transform: + +im2col transform affects the component shapes/strides of the activation tensor in the following way: +- ZPQ Shape : changes DHW domain with formula `(1 + (DHW + pad - (((TRS-1) * dilation) + 1)) / traversal_stride)` +- TRS Shape : TRS domain instead of `(1,1,1)` +- ZPQ Strides : Original DHW strides get `elem_scale()`-ed by traversal strides DHW +- TRS Strides : Original DHW strides get `elem_scale()`-ed by dilation DHW + +With this transform applied, we end up with a set of input and output tensors that +are logically consistent in their MNK dimensions, thus allowing us to dispatch to a GEMM. +im2col activation layout: ((N,(Z,P,Q)), (C,(T,R,S))) // logical (M,K) +filter layout : ( K, (C,(T,R,S))) // logical (N,K) +output layout : ((N,(Z,P,Q)), K ) // logical (M,N) + +CuTe's layout representation and algebra make these folded tensors easy to represent and manipulate. +This is most evident in the reference check code used in this example: + +```cpp +for (size_t logical_m = 0; logical_m < size<0>(mOutputRef); ++logical_m) { + for (size_t logical_n = 0; logical_n < size<1>(mOutputRef); ++logical_n) { + auto accumulator = float(0); + for (size_t logical_k = 0; logical_k < size<1>(mStencil); ++logical_k) { + accumulator += mStencil(logical_m, logical_k) * mActivation(logical_n, logical_k); + } + mOutputRef(logical_m, logical_n) = accumulator; + } +} +``` + +Which succinctly demonstrates how im2col transform allows us to implement convolutions +as GEMMs with special layout transformations on the input tensor. + +Note: in the example kernel's implementation we treat activations as the B tensor +and filter as the A tensor, thus making their logical dimensions NK and MK respectively. + +## Leveraging CUTLASS collectives off the shelf in a custom kernel + +Now that we have transformed our problem in such a way that allows us to dispatch to a GEMM, +we can reuse much of the machinery CUTLASS offers to implement this forward pass convolution +operator. CUTLASS decomposes these "moving parts" of GPU linear algebra into reusable, +modular software components abstracted by C++ template classes. This example +demonstrates how some of the lower layers of the hierarchy can be re-used for custom kernels +by writing a custom kernel for convolution that re-uses the Ampere/Ada GEMM collectives +from CUTLASS 3. + +A kernel author is free to compose their custom components with any of the existing templates +in the CUTLASS hierarchy to leverage existing high performance implementations from the CUTLASS +team. In this example, we write a custom kernel layer and compose with an existing collective. +However, any of the CUTLASS kernels can be composed with bespoke collectives if the desired +customization is a mainloop or epilogue fusion without changes to the grid planning, +tile scheduling, load balancing, or thread marshalling. + +## Implementing gather/scatter and dense convolution with the same kernel + +Functionality and correctness of the implemented kernel, as a virtue of using +CuTe and off the shelf CUTLASS collectives, only relies on the logical consistency of +the layouts of input and output tensors. 
This means that we can freely change how +the logical coordinates of the tensors map into the index space, and even how these dereferences +happen. [CUTLASS example 52](../52_hopper_gather_scatter_fusion/) demonstrates this by implementing a custom stride that +supports indexed indirection for tensor data accesses. This allows for example 52 +to implement a GEMM where inputs are gathered and output is scattered based on an index buffer. + +We re-use the same custom stride utilities in this example to implement a convolution kernel +that gathers along the NDHW dimensions of the activation tensor and scatters the output along the +NZPQ dimensions of the output tensor, treating the channel dimensions as the dense vectors. + +Our dense affine im2col transformed activation tensor: + +```cpp +// im2col transformed activation layout: ((nzpq), (ctrs)) => idx +auto xformed_act_layout = make_layout( + make_shape (make_shape ( N, Z, P, Q), make_shape ( C, T, R, S)), + make_stride(make_stride(D*H*W*C, H*W*C, W*C, C), make_stride(_1{}, H*W*C, W*C, C))); +``` + +now becomes a composed layout that uses `IndexedGather`: + +```cpp +// Inner layout of the composition: +// ((nzpq), (csrt)) => (idx_buffer_idx, dense_offset) +auto EG = E<0>{}; // Gather basis (1,0) (idx_buffer_idx) +auto EC = E<1>{}; // Contiguous basis (0,1) (dense_offset) +auto xformed_act_logical_inner = make_layout( + make_shape (make_shape ( N, Z, P, Q), make_shape ( C, T, R, S)), + make_stride(make_stride(D*H*W*EG, H*W*EG, W*EG, EG), make_stride(EC, H*W*EG, W*EG, EG))); + +// Outer layout of the composition: +// (idx_buffer_idx, dense_offset) => idx +// IndexedGather obtains idx by applying (gmem_base_ptr + gather_idx_buf[idx_buffer_idx] + dense_offset) +auto xformed_act_gather_outer = make_layout( + make_shape(_1{},_1{}), + make_stride(CustomStride{IndexedGather{gather_idx_buf}, C}, _1{})); + +// Compose the inner and outer layouts +// ((nzpq), (ctrs)) => idx +auto xformed_act_composed_layout = composition( + xformed_act_gather_outer, + make_arithmetic_tuple(_0{}, _0{}), + xformed_act_logical_inner); +``` + +Here, we create a composed layout whose inner layout has the same logical MK shape as earlier, +but with an outer layout that uses the custom strides with an index buffer to access memory with +indirections. A custom stride requires two inputs to compute the index that a certain coordinate maps to: +the index buffer offset and the dense offset into the vector. This entails that our inner layout +(the one with the logical MK shape) has a rank-2 codomain `(idx_buffer_idx, dense_offset)`. +We can set up such a layout with scaled basis strides, which allow us to map a domain onto a +codomain with multiple orthogonal bases. The two codomain basis are the +index buffer offsets (rank 0 basis), and the dense vector offsets (rank 1 basis). +A similar composed layout is set up for the output scatter tensor. + +This tensor still has a logical MK shape and is backed by a CuTe layout, which means we can still +tile, partition, and otherwise manipulate it with CuTe's layout algebra in the same way we would any +other tensor. Substituting the activation tensor's affine layout for this gather layout requires +no changes to the implementation of the kernel whatsoever. Everything composes. This example +stamps out a dense 3D convolution as well as gather/scatter 3D convolution using the same kernel template, +with the only difference between them being the layouts of the input and output tensors. 
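+
+At the call site, this amounts to constructing the tensors with a different layout and then launching the very same kernel. The sketch below is simplified from the launch code in `ampere_gather_scatter_conv.cu`; the tensor, layout, and launch-configuration names (`lauch_grid`, `smem_size`, and so on) are the ones defined in that file:
+
+```cpp
+// Dense fprop: activations viewed through the affine im2col layout.
+Tensor mXformedAct = make_tensor(make_gmem_ptr(activations), xformed_act_layout);
+kernel_entrypoint<AmpereUnpredicatedFprop, decltype(mFilter), decltype(mXformedAct), decltype(mOutput)>
+  <<<lauch_grid, AmpereUnpredicatedFprop::MaxThreadsPerBlock, smem_size>>>(mFilter, mXformedAct, mOutput);
+
+// Gather/scatter fprop: identical kernel template, only the layouts change.
+Tensor mXformedActGather = make_tensor(make_gmem_ptr(activations), xformed_act_composed_layout);
+kernel_entrypoint<AmpereUnpredicatedFprop, decltype(mFilter), decltype(mXformedActGather), decltype(mOutputScatter)>
+  <<<lauch_grid, AmpereUnpredicatedFprop::MaxThreadsPerBlock, smem_size>>>(mFilter, mXformedActGather, mOutputScatter);
+```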
+ +Convolutions are just a special case of tensor contractions, and as [example 51](../51_hopper_gett) +demonstrates, the exact same collective used in this example can also be used to implement arbitrary GETTs. +Of course, this also means that the same kernel can implement gather/scatter GETTs as well! + +This demonstrates the composition power of not just CuTe, but also CUTLASS 3's two level +micro kernel abstraction. A single highly tuned temporal micro-kernel (collective) can be implemented once +and applied to compute dense GETTs, gather/scatter GETTs, dense convolutions, and gather/scatter convolutions. + +## Peak performance on Ampere and Ada GPUs by leveraging domain specific knowledge + +Often, when implementing custom kernels, a user has more knowledge of the problem domain that can be +exploited to deliver higher performance than otherwise could be through general kernels. In this example +we presume that the shape of each of the images (DHWC dimensions) as well as the filter (TRS) are available +a-priori and that the tile shape evenly divides the problem. Number of images (N) is still left as a runtime +parameter. + +Knowing the extents of our tensors at compile time allows us to encode them as static cute shapes rather than +a dynamic problem shape, resulting in the elimination of most of the index computation instructions such as +expensive div/mods. Knowing that the problem shape is divisible by the tile shape allows us to use the +Ampere collective that does not perform predication on global memory loads, further reducing overheads +and allowing us to achieve near peak performance on RTX Ampere and Ada GPUs. + +Running this example on an RTX 3080Ti prints the following performance numbers (some output culled for brevity): + +``` +$> ./examples/59_ampere_gather_scatter_conv/59_ampere_gather_scatter_conv --n=131072 --i=128 --no-check +Ampere convolution forward propogation kernel supporting both affine and gather/scatter tensors. + +Allocating tensors ... done. +Initializing data ... done. +Initializing gather/scatter index buffers ... done. + +Running dense fprop kernel +Conv TFLOP count = 0.927713 +Conv dense perf: 31.027376ms | TFLOP/s = 29.899819 + +Running gather/scatter fprop kernel +Conv TFLOP count = 0.927713 +Conv gather/scatter perf: 28.973721ms | TFLOP/s = 32.019117 +``` + +With this in mind, this example kernel has the following limitations: +- This example kernel only supports dynamic image count, all other conv problem shape must be defined as `cute::Constant<>`s +- Problem shapes (including dynamic image count `N`) must be evenly divisible by the tile shape +- It does not perform fp32->tf32 numeric conversion, gmem inputs must be rounded to tf32 already diff --git a/examples/59_ampere_gather_scatter_conv/ampere_conv_kernel.h b/examples/59_ampere_gather_scatter_conv/ampere_conv_kernel.h new file mode 100644 index 0000000000..cc00cced96 --- /dev/null +++ b/examples/59_ampere_gather_scatter_conv/ampere_conv_kernel.h @@ -0,0 +1,320 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/tensor.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_atom.hpp" +#include + +#include "cutlass/util/print_error.hpp" + +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_mma.hpp" + +using namespace cute; + +struct AmpereUnpredicatedFprop { + // + // Static config for conv problem shape + // + using D = _6; + using H = _4; + using W = _4; + + using T = _3; + using R = _3; + using S = _3; + + using Z = _4; + using P = _2; + using Q = _2; + + using C = _64; + using K = _128; + + // Tiler config + using Tiler_K = decltype(cute::min(K{}, _128{})); + using Tiler_C = decltype(cute::min(C{}, _32{})); + using Tiler_N = _4; + using TileM = Tiler_K; + using TileN = Shape; + using TileK = Shape; + using PIPE = _3; + using TilerFlt = Shape; + using TilerAct = Shape; + using TilerOut = Shape; + + using TileSizeM = Int; + using TileSizeN = Int; + using TileSizeK = Int; + static constexpr int Stages = PIPE::value; + + using ElementFlt = tfloat32_t; + using ElementAct = tfloat32_t; + using ElementOut = float; + + using TiledMma = TiledMMA< + MMA_Atom, + Layout>, + Tile<_32,_32,Underscore>>; + + static constexpr int MaxThreadsPerBlock = size(TiledMma{}); + static constexpr int MinBlocksPerMultiprocessor = 1; + + union SharedStorage { + struct { + ElementFlt sAMatrix[size(TileM{}) * size(TileK{}) * size(PIPE{})]; + ElementAct sBMatrix[size(TileN{}) * size(TileK{}) * size(PIPE{})]; + } mainloop; + + struct { + ElementOut sCMatrix[size(TileM{}) * size(TileN{})]; + } epilogue; + }; + + // + // Stencil tensor + // + + using GmemLayoutFlt = decltype(make_ordered_layout( + Shape< K, Shape< C, T, R, S>>{}, + tuple<_4, tuple<_0,_3,_2,_1>>{})); + + // We have 64 elements * 32b each in the major mode that we can vectorize + // Max vector size is 128b, so lay 16 threads along the major mode with a vector size of 4 + // Rest along the minor mode + using GmemTiledCopyFlt = decltype(make_tiled_copy( + Copy_Atom, ElementFlt>{}, + Layout, + Stride< _8, _1>>{}, + Layout>{})); + + // Following layout is also correct, but trades off dynamic strides in the slice for bank conflict free accesses + // using SmemLayoutFlt = decltype( + // 
composition(Swizzle<3,2,3>{}, + // make_ordered_layout( + // Shape{}, + // tuple< _1, _0, _2>{}))); + + using SmemLayoutAtomFlt = decltype( + composition(Swizzle<1,2,3>{}, + Layout>, + Stride<_4,Stride<_1,_32>>>{})); + + using SmemCopyAtomFlt = Copy_Atom; + + // + // Activation tensor + // + + // Activation tensor is major in the contraction mode, so vectorize that mode first + // Then lay out the rest of the threads along the other mode + using GmemTiledCopyAct = decltype(make_tiled_copy( + Copy_Atom, ElementAct>{}, + Layout, + Stride< _8, _1>>{}, + Layout>{})); + + // Following layout is also correct, but trades off dynamic strides in the slice for bank conflict free accesses + // using SmemLayoutAct = decltype( + // composition(Swizzle<3,2,3>{}, + // make_ordered_layout( + // Shape{}, + // tuple< _1, _0, _2>{}))); + + using SmemLayoutAtomAct = decltype( + composition(Swizzle<1,2,3>{}, + Layout>, + Stride<_4,Stride<_1,_32>>>{})); + + using SmemCopyAtomAct = Copy_Atom; + + // + // Output tensor + // + + using GmemTiledCopyOut = decltype(make_tiled_copy( + Copy_Atom, ElementAct>{}, + Layout, + Stride<_1, _8>>{}, + Layout>{})); + + using SmemCopyAtomOut = Copy_Atom, ElementOut>; + + // This can be optimized to make accesses BCF, but we use a col-major layout here to show off composability + using SmemLayoutOut = Layout>; + + // + // Conv functor + // + template + void __device__ + operator()(cute::Tensor mFlt, // ( K, (C,T,R,S)) + TensorActivation mAct, // ((N,Z,P,Q), (C,T,R,S)) + TensorOutput mOut, // ( K, (N,Z,P,Q)) + char* smem_buf) const { + using namespace cute; + using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveMma< + cutlass::gemm::MainloopSm80CpAsyncUnpredicated, + Shape, + ElementFlt, + Underscore, // Ignore the stride, we are passing full cute::Tensor to operator() + ElementAct, + Underscore, // Ignore the stride, we are passing full cute::Tensor to operator() + TiledMma, + GmemTiledCopyFlt, + SmemLayoutAtomFlt, + SmemCopyAtomFlt, + cute::identity, + GmemTiledCopyAct, + SmemLayoutAtomAct, + SmemCopyAtomAct, + cute::identity>; + + TiledMma tiled_mma; + Tensor accum = partition_fragment_C(tiled_mma, TilerOut{}); + clear(accum); + + // Set up tensors + // NOTE: blockIdx.x projects onto act-NDHW mode, y along the flt-K mode for the sake of higher dynamic range in NDHW + Tensor gA_mk = local_tile(mFlt, TilerFlt{}, make_coord(_,_)); // (BLK_M,BLK_K,m',k') + Tensor gB_nk = local_tile(mAct, TilerAct{}, make_coord(_,_)); // (BLK_N,BLK_K,n',_1) + Tensor gC_mn = local_tile(mOut, TilerOut{}, make_coord(_,_)); // (BLK_M,BLK_N,m',n') + + // Compute m_coord and n_coord with their post-tiled shapes + auto m_coord = idx2crd(int(blockIdx.y), shape<2>(gA_mk)); + auto n_coord = idx2crd(int(blockIdx.x), shape<2>(gB_nk)); + Tensor gA = gA_mk(_,_,m_coord,_); // (BLK_M,BLK_K,k') + Tensor gB = gB_nk(_,_,n_coord,_); // (BLK_N,BLK_K,_1) + Tensor gC = gC_mn(_,_,m_coord,n_coord); // (BLK_M,BLK_N) + + auto k_tile_iter = cute::make_coord_iterator(size<2>(gA)); + int k_tile_count = size<2>(gA); + + CollectiveMainloop collective_mma; + collective_mma( + accum, + gA, + gB, + accum, + k_tile_iter, k_tile_count, + Underscore{}, // no residue since we do not support predication + threadIdx.x, + smem_buf); + + // + // Epilogue + // + SharedStorage& storage = *reinterpret_cast(smem_buf); + Tensor sC = make_tensor(make_smem_ptr(&storage.epilogue.sCMatrix[0]), SmemLayoutOut{}); + + auto smem_tiled_copy_C = make_tiled_copy_C(SmemCopyAtomOut{}, tiled_mma); + auto smem_thr_copy_C = 
smem_tiled_copy_C.get_slice(threadIdx.x); + auto tCrC = smem_thr_copy_C.retile_S(accum); + auto tCsC = smem_thr_copy_C.partition_D(sC); + copy(smem_tiled_copy_C, tCrC, tCsC); + + __syncthreads(); + + GmemTiledCopyOut gmem_tiled_copy_C; + auto gmem_thr_copy_C = gmem_tiled_copy_C.get_slice(threadIdx.x); + auto tDsC = gmem_thr_copy_C.partition_S(sC); + auto tDgC = gmem_thr_copy_C.partition_D(gC); + copy(gmem_tiled_copy_C, tDsC, tDgC); + + #if 0 + if (thread0()) { + print("mAct = "); print(mAct); print('\n'); + print("mFlt = "); print(mFlt); print('\n'); + print("mOut = "); print(mOut); print('\n'); + print("gA = "); print(gA); print('\n'); + print("gB = "); print(gB); print('\n'); + print("gC = "); print(gC); print('\n'); + print("sA = "); print(sA.layout()); print('\n'); + print("sB = "); print(sB.layout()); print('\n'); + print("sC = "); print(sC.layout()); print('\n'); + print("tAgA = "); print(tAgA.layout()); print('\n'); + print("tBgB = "); print(tBgB.layout()); print('\n'); + print("tAsA = "); print(tAsA.layout()); print('\n'); + print("tBsB = "); print(tBsB.layout()); print('\n'); + print("tCsA = "); print(tCsA.layout()); print('\n'); + print("tCsB = "); print(tCsB.layout()); print('\n'); + print("tCrC = "); print(tCrC.layout()); print('\n'); + print("tCsC = "); print(tCsC.layout()); print('\n'); + print("tDsC = "); print(tDsC.layout()); print('\n'); + print("tDgC = "); print(tDgC.layout()); print('\n'); + print("gmem tiled copy A = "); print(gmem_tiled_copy_A); print('\n'); + print("gmem tiled copy B = "); print(gmem_tiled_copy_B); print('\n'); + print("gmem tiled copy C = "); print(gmem_tiled_copy_C); print('\n'); + print("k_tile_count = "); print(size<2>(gA)); print('\n'); + print("k_tile_iter = "); print(*k_tile_iter); print('\n'); + print("K_BLOCK_MAX = "); print(K_BLOCK_MAX); print('\n'); + } + #endif + } +}; + +template +inline int +fprop_reference( + TensorFlt mStencil, // Logical MK: ( K, (C,T,R,S)) + TensorAct mActivation, // Logical NK: ((N,Z,P,Q), (C,T,R,S)) + TensorOut mOutput, // Logical MN: ( K, (N,Z,P,Q)) + TensorOut mOutputRef) { + int32_t N = size<1,0>(mOutputRef); + int32_t Z = size<1,1>(mOutputRef); + int32_t P = size<1,2>(mOutputRef); + int32_t Q = size<1,3>(mOutputRef); + int32_t T = size<1,3>(mStencil); + int32_t R = size<1,2>(mStencil); + int32_t S = size<1,1>(mStencil); + int32_t C = size<1,0>(mStencil); + + size_t K = static_cast(size<0>(mOutputRef)); + size_t NZPQ = static_cast(size<1>(mOutputRef)); + size_t CTRS = static_cast(size<1>(mStencil)); + +#if defined(_OPENMP) + #pragma omp parallel for +#endif + for (size_t logical_m = 0; logical_m < K; ++logical_m) { + for (size_t logical_n = 0; logical_n < NZPQ; ++logical_n) { + auto accumulator = float(0); + for (size_t logical_k = 0; logical_k < CTRS; ++logical_k) { + accumulator += mStencil(logical_m, logical_k) * mActivation(logical_n, logical_k); + } + mOutputRef(logical_m, logical_n) = accumulator; + } + } + + return print_relative_error(mOutput, mOutputRef, /*print_verbose*/ false, /*print_error*/ true, /*error_margin*/ 0.01); +} diff --git a/examples/59_ampere_gather_scatter_conv/ampere_gather_scatter_conv.cu b/examples/59_ampere_gather_scatter_conv/ampere_gather_scatter_conv.cu new file mode 100644 index 0000000000..341d1e9fd1 --- /dev/null +++ b/examples/59_ampere_gather_scatter_conv/ampere_gather_scatter_conv.cu @@ -0,0 +1,392 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+/*! \file
+ \brief Example demonstrating a CuTe and CUTLASS 3.x based Ampere convolution forward propagation kernel
+ capable of operating on both affine and gather/scatter tensors.
+
+ This example demonstrates a few super cool features of CUTLASS and CuTe. It shows off
+ 1. A dense conv 3D fprop kernel written as a single file ...
+ 2. ... that leverages off-the-shelf CUTLASS collectives to show how custom kernels can use collectives ...
+ 3. ... and uses the exact same templated kernel to also stamp out a gather/scatter 3D fprop conv ...
+ 4. ... while getting near peak performance of the Ampere class tensor core on Ampere and Ada GPUs ...
+ 5. ... by using static cute shapes and strides in case problem shapes are known at compile time.
+
+ Full documentation for this example can be found within the README.md file in this directory.
+ + Example executions: + ./59_ampere_gather_scatter_conv + ./59_ampere_gather_scatter_conv --n=108 + ./59_ampere_gather_scatter_conv --n=4096 --i=1 + ./59_ampere_gather_scatter_conv --n=1080 --i=1000 + ./59_ampere_gather_scatter_conv --n=131072 --i=1000 --no-check +*/ + +#include +#include + +#include "ampere_conv_kernel.h" +#include "gather_tensor.hpp" + +#include "cutlass/util/command_line.h" + +bool check_cuda_result(cudaError_t code, const char* file, int line) { + if (code == cudaSuccess) { + return true; + } + + std::cerr << "CUDA error at (" << file << "," << line << ")\n\t" << unsigned(code) << " -- " << cudaGetErrorString(code) << "\n"; + return false; +} + +#define CHECK_CUDA(code) (check_cuda_result(code, __FILE__, __LINE__)) + +using namespace cute; +using example::IndexedGather; +using example::CustomStride; + +template +__global__ +__launch_bounds__(Operator::MaxThreadsPerBlock, Operator::MinBlocksPerMultiprocessor) +void kernel_entrypoint(FilterTensor mFlt, ActivationTensor mAct, OutputTensor mOut) { + extern __shared__ char smem_buf[]; + Operator op; + op(mFlt, mAct, mOut, smem_buf); +} + +int ampere_dense_conv_fprop( + int num_images, + float* activations, + float* filter, + float* output, + float* output_ref, + int num_iterations = 1, + bool do_ref_check = true) { + auto D = typename AmpereUnpredicatedFprop::D{}; + auto H = typename AmpereUnpredicatedFprop::H{}; + auto W = typename AmpereUnpredicatedFprop::W{}; + auto Z = typename AmpereUnpredicatedFprop::Z{}; + auto P = typename AmpereUnpredicatedFprop::P{}; + auto Q = typename AmpereUnpredicatedFprop::Q{}; + auto C = typename AmpereUnpredicatedFprop::C{}; + auto K = typename AmpereUnpredicatedFprop::K{}; + auto S = typename AmpereUnpredicatedFprop::S{}; + auto R = typename AmpereUnpredicatedFprop::R{}; + auto T = typename AmpereUnpredicatedFprop::T{}; + + int N = num_images; // dynamic + if (num_images % int(typename AmpereUnpredicatedFprop::Tiler_N{}) != 0) { + printf("ERROR: Input image count must be evenly divisible by CTA tiler N.\n"); + return 1; + } + + // Tensor Activation: (n,d,h,w,c)::(?,6,4,4,64):(6144,1536,384,64,1) + auto activation_layout = make_layout( + make_shape (make_shape ( N, D, H, W), make_shape ( C, _1{},_1{},_1{})), + make_stride(make_stride(D*H*W*C, H*W*C, W*C, C), make_stride(_1{}, _0{},_0{},_0{}))); + + auto xformed_act_layout = make_layout( + make_shape (make_shape(N, Z, P, Q), make_shape ( C, T, R, S)), + make_stride(stride<0>(activation_layout), make_stride(_1{}, H*W*C, W*C, C))); + + // Tensor Filter : (k,c,s,r,t)::(128,3,3,3,64):(1728,576,192,64,1) + auto filter_layout = AmpereUnpredicatedFprop::GmemLayoutFlt{}; + + // Tensor Output : (n,z,p,q,k)::(?,4,2,2,128):(2048,1024,512,128,1) + auto output_layout = make_ordered_layout( + make_shape( K, make_shape( N, Z, P, Q)), + make_tuple(_0{}, make_tuple(_4{},_3{},_2{},_1{}))); + + Tensor mActivation = make_tensor(make_gmem_ptr(activations), activation_layout); + Tensor mXformedAct = make_tensor(make_gmem_ptr(activations), xformed_act_layout); + Tensor mFilter = make_tensor(make_gmem_ptr(filter), filter_layout); + Tensor mOutput = make_tensor(make_gmem_ptr(output), output_layout); // (K, (N,Z,P,Q)) + Tensor mOutputRef = make_tensor(make_gmem_ptr(output_ref), output_layout); + + print("xformed act layout ((N,Z,P,Q), (C,T,R,S)) = "); print(xformed_act_layout); print("\n"); + + cudaEvent_t start, stop; + CHECK_CUDA(cudaEventCreate(&start)); + CHECK_CUDA(cudaEventCreate(&stop)); + + constexpr size_t smem_size = sizeof(typename 
AmpereUnpredicatedFprop::SharedStorage); + Tensor gOutput_mn = zipped_divide(mOutput, typename AmpereUnpredicatedFprop::TilerOut{}); // ((BLK_M, BLK_N), (m', n')) + dim3 lauch_grid {static_cast(size<1,1>(gOutput_mn)), static_cast(size<1,0>(gOutput_mn)), 1}; + + CHECK_CUDA(cudaFuncSetAttribute( + kernel_entrypoint, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + + CHECK_CUDA(cudaEventRecord(start)); + for (int i = 0; i < num_iterations; ++i) { + kernel_entrypoint + <<>>( + mFilter, mXformedAct, mOutput); + } + CHECK_CUDA(cudaEventRecord(stop)); + CHECK_CUDA(cudaEventSynchronize(stop)); + + float milliseconds = 0; + cudaEventElapsedTime(&milliseconds, start, stop); + milliseconds /= float(num_iterations); + + double tflop_count = (2 * double(size<0>(xformed_act_layout)) * double(size(filter_layout))) / double(1e12); + double tflops = tflop_count / (double(milliseconds) / double(1e3)); + + printf("Conv TFLOP count = %f\n", tflop_count); + printf("Conv dense perf: %fms | TFLOP/s = %f\n", milliseconds, tflops); + + if (do_ref_check) { + printf("Running host reference check ...\n"); + return fprop_reference(mFilter, mXformedAct, mOutput, mOutputRef); + } + else { + return 0; + } +} + +int ampere_gather_scatter_conv_fprop( + int num_images, + float* activations, + uint32_t *gather_idx_buf, + float* filter, + float* output, + uint32_t *scatter_idx_buf, + int num_iterations = 1) { + auto D = typename AmpereUnpredicatedFprop::D{}; + auto H = typename AmpereUnpredicatedFprop::H{}; + auto W = typename AmpereUnpredicatedFprop::W{}; + auto Z = typename AmpereUnpredicatedFprop::Z{}; + auto P = typename AmpereUnpredicatedFprop::P{}; + auto Q = typename AmpereUnpredicatedFprop::Q{}; + auto C = typename AmpereUnpredicatedFprop::C{}; + auto K = typename AmpereUnpredicatedFprop::K{}; + auto S = typename AmpereUnpredicatedFprop::S{}; + auto R = typename AmpereUnpredicatedFprop::R{}; + auto T = typename AmpereUnpredicatedFprop::T{}; + + int N = num_images; // dynamic + if (N % int(typename AmpereUnpredicatedFprop::Tiler_N{}) != 0) { + printf("ERROR: Input image count must be evenly divisible by CTA tiler N. 
Got num_images = %d\n", N); + return 1; + } + + // Tensor Filter : (k,c,s,r,t)::(128,3,3,3,64):(1728,576,192,64,1) + auto filter_layout = AmpereUnpredicatedFprop::GmemLayoutFlt{}; + + // Tensor Output : (n,z,p,q,k)::(?,4,2,2,128):(2048,1024,512,128,1) + auto output_layout = make_ordered_layout( + make_shape( K, make_shape( N, Z, P, Q)), + make_tuple(_0{}, make_tuple(_4{},_3{},_2{},_1{}))); + + // Input gather layout + // inner_layout(make_coord((nzpq), (csrt))) => (idx_buffer_idx, dense_c_idx) + auto EG = E<0>{}; // Gather basis (1,0) (idx_buffer_idx) + auto EC = E<1>{}; // Contiguous basis (0,1) (dense_offset) + auto xformed_act_logical_inner = make_layout( + make_shape (make_shape ( N, Z, P, Q), make_shape ( C, T, R, S)), + make_stride(make_stride(D*H*W*EG, H*W*EG, W*EG, EG), make_stride(EC, H*W*EG, W*EG, EG))); + + // outer_layout(make_coord(idx_buffer_idx, dense_c_idx)) => idx + // IndexedGather obtains idx by applying (gmem_base_ptr + gather_idx_buf[idx_buffer_idx] + dense_offset) + auto xformed_act_gather_outer = make_layout( + make_shape(_1{},_1{}), + make_stride(CustomStride{IndexedGather{gather_idx_buf}, C}, _1{})); + + // Compose the inner and outer layouts + // gather_composed(make_coord((nzpq), (csrt))) => idx + auto xformed_act_composed_layout = composition( + xformed_act_gather_outer, + make_arithmetic_tuple(_0{}, _0{}), + xformed_act_logical_inner); + + // Output scatter layout + auto out_basis_stride = make_stride( + E<1>{}, + make_stride(Z*P*Q*E<0>{}, P*Q*E<0>{}, Q*E<0>{}, _1{}*E<0>{})); // -> (crd0, crd1) + auto out_basis_layout = make_layout(shape(output_layout), out_basis_stride); + auto out_scatter_layout = make_layout( + make_shape(_1{},_1{}), + make_stride(CustomStride{IndexedGather{scatter_idx_buf}, K}, _1{})); + auto out_composed_layout = composition( + out_scatter_layout, + make_arithmetic_tuple(_0{},_0{}), + out_basis_layout); + + Tensor mXformedActGather = make_tensor(make_gmem_ptr(activations), xformed_act_composed_layout); + Tensor mFilter = make_tensor(make_gmem_ptr(filter), filter_layout); + Tensor mOutputScatter = make_tensor(make_gmem_ptr(output), out_composed_layout); // (K, (N,Z,P,Q)) + + Tensor gOutput_mn = zipped_divide(mOutputScatter, typename AmpereUnpredicatedFprop::TilerOut{}); // ((BLK_M, BLK_N), (m', n')) + dim3 lauch_grid {static_cast(size<1,1>(gOutput_mn)), static_cast(size<1,0>(gOutput_mn)), 1}; + constexpr size_t smem_size = sizeof(typename AmpereUnpredicatedFprop::SharedStorage); + + print("xforemed gather layout ((N,Z,P,Q), (C,T,R,S)) = "); print(xformed_act_composed_layout); print("\n"); + print("Output scatter layout ( K, (N,Z,P,Q)) = "); print(out_composed_layout); print("\n"); + print("Filter layout ( K, (C,T,R,S)) = "); print(filter_layout); print("\n"); + + CHECK_CUDA(cudaFuncSetAttribute( + kernel_entrypoint, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + + cudaEvent_t start, stop; + CHECK_CUDA(cudaEventCreate(&start)); + CHECK_CUDA(cudaEventCreate(&stop)); + CHECK_CUDA(cudaEventRecord(start)); + for (int i = 0; i < num_iterations; ++i) { + kernel_entrypoint + <<>>( + mFilter, mXformedActGather, mOutputScatter); + } + CHECK_CUDA(cudaEventRecord(stop)); + CHECK_CUDA(cudaEventSynchronize(stop)); + float milliseconds = 0; + cudaEventElapsedTime(&milliseconds, start, stop); + milliseconds /= float(num_iterations); + + double tflop_count = (2 * double(size<0>(xformed_act_logical_inner)) * double(size(filter_layout))) / double(1e12); + double tflops = tflop_count / (double(milliseconds) / double(1e3)); + printf("Conv TFLOP 
count = %f\n", tflop_count); + printf("Conv gather/scatter perf: %fms | TFLOP/s = %f\n", milliseconds, tflops); + + return 0; +} + +int +main(int argc, char const** argv) { + cutlass::CommandLine cmd(argc, argv); + std::cout << "Ampere convolution forward propogation kernel supporting both affine and gather/scatter tensors.\n\n"; + if (cmd.check_cmd_line_flag("help")) { + std::cout + << "Options:\n" + "\t--n= Sets the number of images for the input activation tensor (dataset size). Default = 131072.\n" + "\t--i= Sets the benchmarking repetitions. Default = 128.\n" + "\t--nocheck If specified, skips the reference check for dense kernel.\n" + "\t--help Displays this help message and exits.\n"; + return 0; + } + + + cudaDeviceProp props; + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + if (props.major < 8) { + std::cerr << "This example requires an Ampere GPU or newer.\n"; + return 0; + } + + int num_images = 4320; + cmd.get_cmd_line_argument("n", num_images, 4320); + int num_iterations = 128; + cmd.get_cmd_line_argument("i", num_iterations, 128); + bool do_host_ref_check = not cmd.check_cmd_line_flag("no-check"); + + auto D = typename AmpereUnpredicatedFprop::D{}; + auto H = typename AmpereUnpredicatedFprop::H{}; + auto W = typename AmpereUnpredicatedFprop::W{}; + auto Z = typename AmpereUnpredicatedFprop::Z{}; + auto P = typename AmpereUnpredicatedFprop::P{}; + auto Q = typename AmpereUnpredicatedFprop::Q{}; + auto C = typename AmpereUnpredicatedFprop::C{}; + auto K = typename AmpereUnpredicatedFprop::K{}; + + auto activation_layout = make_layout( + make_shape (make_shape (num_images, D, H, W), make_shape ( C, _1{},_1{},_1{})), + make_stride(make_stride( D*H*W*C, H*W*C, W*C, C), make_stride(_1{}, _0{},_0{},_0{}))); + + auto filter_layout = typename AmpereUnpredicatedFprop::GmemLayoutFlt{}; + + auto output_layout = make_ordered_layout( + make_shape( K, make_shape(num_images, Z, P, Q)), + make_step (_0{}, make_step ( _4{},_3{},_2{},_1{}))); + + print("Filter layout ( K, (C,T,R,S)) = "); print(filter_layout); print("\n"); + print("Activation layout ((N,D,H,W), (C,1,1,1)) = "); print(activation_layout); print("\n"); + print("Output layout ( K, (N,Z,P,Q)) = "); print(output_layout); print("\n"); + + // allocate tensors + std::cout << "Allocating tensors ... "; + thrust::universal_vector activation_data(size_t(cute::size(activation_layout)), float(0)); + thrust::universal_vector filter_data(size_t(cute::size(filter_layout)), float(0)); + thrust::universal_vector output_data(size_t(cute::size(output_layout)), float(0)); + thrust::universal_vector output_data_ref(size_t(cute::size(output_layout)), float(0)); + std::cout << "done.\n"; + + // init tensors + std::cout << "Initializing data ... " << std::flush; + std::random_device rd; + std::mt19937 gen(rd()); + std::uniform_real_distribution uniform_dist(-1.0, 1.0); + for (std::size_t i = 0; i < size_t(cute::size(activation_layout)); ++i) { + activation_data[i] = uniform_dist(gen); + } + + for (std::size_t i = 0; i < size_t(cute::size(filter_layout)); ++i) { + filter_data[i] = uniform_dist(gen); + } + std::cout << "done.\n"; + + // set up index buffers for gather/scatter, fill with indireciton indices in reversed order + std::cout << "Initializing gather/scatter index buffers ... 
"; + thrust::universal_vector gather_idx_buf(size_t(size<0>(activation_layout))); + thrust::universal_vector scatter_idx_buf(size_t(size<1>(output_layout))); + thrust::sequence(gather_idx_buf.rbegin(), gather_idx_buf.rend()); + thrust::sequence(scatter_idx_buf.rbegin(), scatter_idx_buf.rend()); + std::cout << "done.\n"; + + // launch dense + std::cout << "\nRunning dense fprop kernel\n"; + int passed = ampere_dense_conv_fprop( + num_images, + activation_data.data().get(), + filter_data.data().get(), + output_data.data().get(), + output_data_ref.data().get(), + num_iterations, + do_host_ref_check); + + // launch gather/scatter + std::cout << "\nRunning gather/scatter fprop kernel\n"; + ampere_gather_scatter_conv_fprop( + num_images, + activation_data.data().get(), + gather_idx_buf.data().get(), + filter_data.data().get(), + output_data.data().get(), + scatter_idx_buf.data().get(), + num_iterations); + + return passed; +} diff --git a/examples/60_cutlass_import/CMakeLists.txt b/examples/60_cutlass_import/CMakeLists.txt new file mode 100644 index 0000000000..974bf4102b --- /dev/null +++ b/examples/60_cutlass_import/CMakeLists.txt @@ -0,0 +1,66 @@ +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# This example demonstrates building against and utilizing an +# installed CUTLASS library. Unlike the other examples, this example +# is not built within the standard CUTLASS CMake flow, but rather +# relies on a pre-installed CUTLASS package. If the CUTLASS package is +# not installed in a standard location, provide the root location of +# the install with "-DCUTLASS_DIR=" CMake +# argument or any of the other features CMake allows for specifying +# locations of installed CMake packages via find_package(). 
+ +cmake_minimum_required(VERSION 3.18 FATAL_ERROR) + +project(cutlass_import_example VERSION 0.2 LANGUAGES CXX CUDA) + +if (CUTLASS_DIR) + message(STATUS "Using CUTLASS specified at ${CUTLASS_DIR}.") + list(APPEND CMAKE_PREFIX_PATH ${CUTLASS_DIR}) +endif() + +find_package(NvidiaCutlass 2.0 REQUIRED) + +message(STATUS "CUTLASS: ${NvidiaCutlass_DIR}") + +add_executable(example) + +target_sources(example PRIVATE main.cpp) + +target_include_directories( + example + PRIVATE + ${CMAKE_CUDA_TOOLKIT_INCLUDE_DIRECTORIES} + ) + +target_link_libraries( + example + PRIVATE + nvidia::cutlass::cutlass + nvidia::cutlass::library + ) diff --git a/examples/60_cutlass_import/main.cpp b/examples/60_cutlass_import/main.cpp new file mode 100644 index 0000000000..f17f545892 --- /dev/null +++ b/examples/60_cutlass_import/main.cpp @@ -0,0 +1,66 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief This example demonstrates utilizing an installed CUTLASS library. +*/ + +#include + +#include +#include + +int main(int argc, char ** argv) { + + // The operations built into the CUTLASS library are managed by a + // Manifest. The manifest is populated with a call to one of the + // "initialize" methods. + + cutlass::library::Manifest manifest; + + initialize_all(manifest); + + // Once initialized, the manifest can be queried for operations, + // and those operations can be further inspected via methods + // exposed in the library headers. + // + // Here, we simply enumerate the embedded kernels and list them. + + auto & opVec = manifest.operations(); + + std::cout << "Manifest contains " << opVec.size() << " operations, listed below." 
<< std::endl; + + for(auto opIter = manifest.begin(); opIter != manifest.end(); ++opIter) { + std::cout << (*opIter)->description().name << std::endl; + } + + return 0; +} diff --git a/examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu b/examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu new file mode 100644 index 0000000000..8bb14b4556 --- /dev/null +++ b/examples/61_hopper_gemm_with_topk_and_softmax/61_hopper_gemm_with_topk_and_softmax.cu @@ -0,0 +1,534 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Hopper GEMM + Top-K + Softmax fusion + + This example illustrates how to use the LinCombTopKSoftmaxCol EVT node to fuse + Top-K and Softmax into the GEMM epilogue, with certain assumptions made. + + Those assumptions are as: + 1. Fusion is over the N dimension. + 2. Top-K is either 2 or 4 elements, and the value is static (meaning two kernels have to be + compiled to support both.) + 3. The GEMM tile shape along N is greater than or equal to problem size + along N. + + + The example runs the fused GEMM kernel, along with a standard unfused host reference, and + manually performs Top-K and softmax, and compares the error between tensors. + + Note that some numerical error (smaller than 1e-5) is to be expected, but this is true + in most efficient reduction kernels, because floating point addition is not necessarily + associative. 
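+
+ As a rough worked illustration of the fusion under the assumptions above (the numbers are
+ purely illustrative): with TopK = 2, a GEMM output row of [0.1, 2.0, -1.0, 1.0] keeps its two
+ largest entries, 2.0 and 1.0, and the softmax for that row is computed over those selected
+ values, i.e. exp(2.0) / (exp(2.0) + exp(1.0)) ~= 0.731 and exp(1.0) / (exp(2.0) + exp(1.0)) ~= 0.269.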
+*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/error_metrics.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gett.hpp" + + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +static constexpr int TopK = 2; +static constexpr bool EnableTopKSoftmax = TopK > 1; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C matrix configuration +using ElementC = void; +using LayoutC = cutlass::layout::RowMajor; +constexpr int AlignmentC = 1; + +// D matrix configuration +using ElementD = cutlass::half_t; // Element type for C and D matrix operands +using LayoutD = cutlass::layout::RowMajor; // Layout type for output +constexpr int AlignmentD = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of output in units of elements (up to 16 bytes) + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for epilogue computation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_64,_64,_128>; // Threadblock-level tile size +using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized; +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; + +// Top-K + Softmax fusion operation +using FusionOperation = std::conditional_t, + typename cutlass::epilogue::fusion::LinearCombination +>; + +// The fusion op only allows for epilogue tiles matching the mainloop tile. 
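+// cute::take<0,2>(TileShape{}) below keeps only the first two modes (M and N) of the CTA tile,
+// i.e. the full 64x64 output tile of this configuration, rather than letting the epilogue
+// builder choose a smaller epilogue tile automatically.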
+using EpilogueTileType = decltype(cute::take<0,2>(TileShape{})); + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + TileShape, ClusterShape, + EpilogueTileType, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule, + FusionOperation + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Extract information from Gemm kernel. +using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; +using ElementScalar = typename EpilogueOutputOp::ElementScalar; + +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideD = typename Gemm::GemmKernel::StrideD; + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideD stride_D; +uint64_t seed; + +cutlass::HostTensor tensor_A; +cutlass::HostTensor tensor_B; +cutlass::HostTensor tensor_D; +cutlass::HostTensor tensor_ref_D; + +using LayoutScalar = cutlass::layout::PackedVectorLayout; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help = false; + + int iterations = 1000; + int m = 16, n = 8, k = 64, l = 1; + double eps = 1e-5; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("iterations", iterations); + cmd.get_cmd_line_argument("eps", eps); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "61_hopper_gemm_with_topk_and_softmax\n\n" + << " Hopper FP8 GEMM with Top-K and softmax fusion.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the l extent (batch) of the GEMM\n" + << " --iterations= Number of profiling iterations to perform.\n\n" + << " --eps= Threshold of numerical verification. 
Default: 1e-5.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "61_hopper_gemm_with_topk_and_softmax" << " --m=16 --n=8 --k=1024 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } + + float alpha() const { + return 1.f / static_cast(k); + } +}; + +/// Result structure +struct Result { + double avg_runtime_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_tensor( + cutlass::TensorView view, + uint64_t seed) { + cutlass::reference::host::TensorFillRandomUniform( + view, seed, /* max = */ 1, /* min = */ -1, /* bits = */ 2); + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l)); + + auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); + auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); + auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); + + tensor_A.resize(a_coord); + tensor_B.resize(b_coord); + tensor_D.resize(c_coord); + tensor_ref_D.resize(c_coord); + + initialize_tensor(tensor_A.host_view(), seed + 2022); + initialize_tensor(tensor_B.host_view(), seed + 2023); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_D.sync_device(); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) { + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, options.l}, + {tensor_A.device_data(), stride_A, tensor_B.device_data(), stride_B}, + { + {options.alpha(), 0.f}, // alpha, beta + nullptr, stride_D, + tensor_D.device_data(), stride_D + } + }; + + return arguments; +} + +bool verify(const Options &options) { + // + // Compute reference output + // + + // Create instantiation for device reference gemm kernel + auto A = cute::make_tensor(tensor_A.host_data(), + cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A)); + auto B = cute::make_tensor(tensor_B.host_data(), + cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B)); + auto D = cute::make_tensor(tensor_ref_D.host_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D)); + using unused_t = decltype(D); + + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + + 
cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + unused_t, + decltype(D), + unused_t, // bias + unused_t, // aux + unused_t, // valpha + unused_t // vbeta + > epilogue_params; + + epilogue_params.D = D; + epilogue_params.alpha = options.alpha(); + epilogue_params.beta = 0.f; + + // get reference result + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + if constexpr (EnableTopKSoftmax) { + // top-K + softmax + for (int i = 0; i < options.m; ++i) { + + // Find Top-K + cutlass::Array top_k; + top_k.fill(-cutlass::platform::numeric_limits::infinity()); + for (int j = 0; j < options.n; ++j) { + auto val = static_cast(tensor_ref_D.host_view().ref().at({i, j})); + for (int top_k_idx = 0; top_k_idx < TopK; ++top_k_idx) { + if (val > top_k[top_k_idx]) { + // Shift down + for (int l = TopK - 1; l > top_k_idx; --l) { + top_k[l] = top_k[l - 1]; + } + top_k[top_k_idx] = val; + break; + } + } + } + + // This formulation of top-K + softmax only works when it is + // guaranteed that none of the top-K elements are repeated! + // If this is the case, the device kernel can also make mistakes, because + // A. Once the top-K values are reduced, and the operation is being applied, + // there is no way to tell repeated elements apart, so none are masked. + // B. The softmax sum of exps will be incorrect (because the repeated elements + // are not repeated in it.) + + ElementAccumulator max = top_k[0]; + ElementAccumulator sum = ElementAccumulator(0.f); + for (int top_k_idx = 0; top_k_idx < TopK; ++top_k_idx) { + sum = sum + cutlass::fast_exp(top_k[top_k_idx] - max); + } + + for (int j=0; j < options.n; ++j) { + auto val = tensor_ref_D.host_view().ref().at({i, j}); + if (val < top_k[TopK - 1]) { + tensor_ref_D.host_view().ref().at({i, j}) = static_cast(0.f); + } else { + // Softmax + auto softmax_val = cutlass::fast_exp(val - max) / sum; + tensor_ref_D.host_view().ref().at({i, j}) = static_cast(softmax_val); + } + } + } + } + + // compare_reference + tensor_D.sync_host(); + + double err = cutlass::reference::host::TensorRelativeErrorMetric( + tensor_D.host_view(), + tensor_ref_D.host_view()); + bool passed = err < options.eps; + + if (options.m <= 32 && options.n <= 32) { + std::cout << "GEMM output:\n" << tensor_D.host_view() << "\n\n"; + std::cout << "Reference output:\n" << tensor_ref_D.host_view() << "\n\n"; + } + + std::cout << " Disposition: " << (passed ? 
"Passed" : "Failed") << " \t Relative error: " << err << std::endl; + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) { + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run()); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.run()); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + result.gflops = options.gflops(result.avg_runtime_ms / 1000.0); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // and must have compute capability at least 90. + if (__CUDACC_VER_MAJOR__ < 12) { + std::cerr << "This example requires CUDA 12 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major < 9) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater).\n"; + return 0; + } + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + run(options); +#endif + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/61_hopper_gemm_with_topk_and_softmax/CMakeLists.txt b/examples/61_hopper_gemm_with_topk_and_softmax/CMakeLists.txt new file mode 100644 index 0000000000..7d9160a733 --- /dev/null +++ b/examples/61_hopper_gemm_with_topk_and_softmax/CMakeLists.txt @@ -0,0 +1,32 @@ +# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +cutlass_example_add_executable( + 61_hopper_gemm_with_topk_and_softmax + 61_hopper_gemm_with_topk_and_softmax.cu + ) diff --git a/examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu b/examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu new file mode 100644 index 0000000000..c3f1ce709a --- /dev/null +++ b/examples/62_hopper_sparse_gemm/62_hopper_sparse_gemm.cu @@ -0,0 +1,596 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Hopper Sparse GEMM example. + + This example demonstrates how to construct and run a structured sparse GEMM kernel + on NVIDIA Hopper architecture. + +*/ + +#include + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/epilogue/collective/default_epilogue.hpp" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/collective/collective_builder.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/transform/device/transform_universal_adapter.hpp" +#include "cutlass/transform/kernel/sparse_gemm_compressor.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/tensor_compare.h" +#include "cutlass/util/reference/device/tensor_fill.h" + +#include "helper.h" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::half_t; // Element type for A matrix operand +using LayoutTagA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::half_t; // Element type for B matrix operand +using LayoutTagB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C/D matrix configuration +using ElementC = float; // Element type for C and D matrix operands +using LayoutTagC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using TileShape = Shape<_128,_128,_128>; // Threadblock-level tile size for sparse kernel +using TileShapeRef = Shape<_128,_128, _64>; // Threadblock-level tile size for reference (dense) kernel +using ClusterShape = Shape<_1,_2,_1>; // Shape of the threadblocks 
in a cluster +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecialized; // Kernel schedule policy +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; // Epilogue schedule policy + +using ProblemShape = Shape; + +// Sparse kernel setup + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShape, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutTagC, AlignmentC, + ElementC, LayoutTagC, AlignmentC, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassSparseTensorOp, + ElementA, LayoutTagA, AlignmentA, + ElementB, LayoutTagB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloop, + CollectiveEpilogue +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Reference (dense) kernel setup + +using CollectiveEpilogueRef = typename cutlass::epilogue::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + TileShapeRef, ClusterShape, + cutlass::epilogue::collective::EpilogueTileAuto, + ElementAccumulator, ElementAccumulator, + ElementC, LayoutTagC, AlignmentC, + ElementC, LayoutTagC, AlignmentC, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloopRef = typename cutlass::gemm::collective::CollectiveBuilder< + cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp, + ElementA, LayoutTagA, AlignmentA, + ElementB, LayoutTagB, AlignmentB, + ElementAccumulator, + TileShapeRef, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage))>, + KernelSchedule + >::CollectiveOp; + +using GemmKernelRef = cutlass::gemm::kernel::GemmUniversal< + ProblemShape, + CollectiveMainloopRef, + CollectiveEpilogue +>; + +using GemmRef = cutlass::gemm::device::GemmUniversalAdapter; + +// Layouts +using LayoutA = typename Gemm::GemmKernel::CollectiveMainloop::LayoutA; +using LayoutE = typename Gemm::GemmKernel::CollectiveMainloop::LayoutE; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +// Layouts for reference (non-sparse) tensors +using StrideA = cutlass::gemm::TagToStrideA_t; +using StrideE = StrideA; + +using ElementE = typename Gemm::GemmKernel::CollectiveMainloop::ElementE; +using SparseConfig = typename Gemm::GemmKernel::CollectiveMainloop::SparseConfig; + +// Offline compressor kernel +using CompressorUtility = cutlass::transform::kernel::StructuredSparseCompressorUtility< + ProblemShape, + ElementA, + LayoutTagA, + SparseConfig>; + +using CompressorKernel = cutlass::transform::kernel::StructuredSparseCompressor< + ProblemShape, + ElementA, + LayoutTagA, + SparseConfig, + cutlass::arch::Sm90>; + +using Compressor = cutlass::transform::device::TransformUniversalAdapter; + +// +// Data members +// + +ProblemShape problem_shape; + +StrideA stride_A; +StrideA stride_A_compressed; +StrideE stride_E; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; + +LayoutA layout_A; +LayoutE layout_E; + +uint64_t seed; + 
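+// Device-side buffers: block_A holds the full A operand consumed by the dense reference GEMM,
+// while block_A_compressed and block_E hold the packed non-zero values and the sparsity
+// metadata produced by the offline compressor, which the structured sparse kernel consumes
+// instead.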
+cutlass::DeviceAllocation block_A; +cutlass::DeviceAllocation block_A_compressed; +cutlass::DeviceAllocation block_E; +cutlass::DeviceAllocation block_B; +cutlass::DeviceAllocation block_C; +cutlass::DeviceAllocation block_D; +cutlass::DeviceAllocation block_D_ref; + +#endif // defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + + float alpha, beta; + int iterations; + int m, n, k, l; + + Options(): + help(false), + m(5120), n(4096), k(16384), l(1), + alpha(1.f), beta(0.f), + iterations(10) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha); + cmd.get_cmd_line_argument("beta", beta); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "62_hopper_sparse_gemm\n\n" + << " Hopper Sparse GEMM example.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the L extent of the GEMM (batch size)\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "62_hopper_sparse_gemm" << " --m=4096 --n=5120 --k=8192 --l=1 --alpha=2 --beta=0.707 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } +}; + +#if defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_block( + cutlass::DeviceAllocation& block, + uint64_t seed) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = Element(2); + scope_min = Element(0); + } else if (bits_input <= 8) { + scope_max = Element(2); + scope_min = Element(-2); + } else { + scope_max = Element(8); + scope_min = Element(-8); + } + + cutlass::reference::device::BlockFillRandomUniform( + block.get(), block.size(), seed, scope_max, scope_min, 0); + + return true; +} + +/// Make A structured sparse by replacing elements with 0 and compress it +bool sparsify_and_compress() +{ + auto [M, N, K, L] = problem_shape; + CompressorUtility compressor_utility(problem_shape, stride_A); + + int ME = compressor_utility.get_metadata_m_physical(); + int KE = compressor_utility.get_metadata_k_physical(); + int KC = compressor_utility.get_tensorA_k_physical(); + + 
block_A_compressed.reset(M * KC * L); + block_E.reset(ME * KE * L); + + stride_A_compressed = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, KC, L)); + stride_E = cutlass::make_cute_packed_stride(StrideE{}, cute::make_shape(ME, KE, L)); + + // Random sparsification is performed on host + std::vector block_A_host(block_A.size()); + cutlass::device_memory::copy_to_host(block_A_host.data(), block_A.get(), block_A.size()); + compressor_utility.structure_sparse_zero_mask_fill(block_A_host.data(), static_cast(seed + 2024)); + cutlass::device_memory::copy_to_device(block_A.get(), block_A_host.data(), block_A.size()); + + cutlass::KernelHardwareInfo hw_info; + hw_info.device_id = 0; + hw_info.sm_count = cutlass::KernelHardwareInfo::query_device_multiprocessor_count(hw_info.device_id); + typename Compressor::Arguments arguments { + problem_shape, + { block_A.get(), + stride_A, + block_A_compressed.get(), + block_E.get() }, + {hw_info} }; + + Compressor compressor_op; + size_t workspace_size = Compressor::get_workspace_size(arguments); + cutlass::device_memory::allocation workspace(workspace_size); + + CUTLASS_CHECK(compressor_op.can_implement(arguments)); + CUTLASS_CHECK(compressor_op.initialize(arguments, workspace.get())); + CUTLASS_CHECK(compressor_op.run()); + CUDA_CHECK(cudaDeviceSynchronize()); + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +bool initialize(Options const& options) { + + problem_shape = make_tuple(options.m, options.n, options.k, options.l); + auto [M, N, K, L] = problem_shape; + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(M, K, L)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(N, K, L)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(M, N, L)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(M, N, L)); + + // Allocate memory for tensors + block_A.reset(M * K * L); + block_B.reset(N * K * L); + block_C.reset(M * N * L); + block_D.reset(M * N * L); + block_D_ref.reset(M * N * L); + + // Fill input tensors with data + initialize_block(block_A, seed + 2021); + initialize_block(block_B, seed + 2022); + initialize_block(block_C, seed + 2023); + + // Replace 0 in A with 1 to avoid metadata changes + std::vector block_A_host(block_A.size()); + cutlass::device_memory::copy_to_host(block_A_host.data(), block_A.get(), block_A.size()); + for (size_t i = 0; i < block_A.size(); ++i) if (block_A_host[i] == ElementA(0)) block_A_host[i] = ElementA(1.0); + cutlass::device_memory::copy_to_device(block_A.get(), block_A_host.data(), block_A.size()); + + if (!sparsify_and_compress()) { + return false; + }; + + // Build the compressed/metadata layouts + layout_A = SparseConfig::fill_layoutA(problem_shape); + layout_E = SparseConfig::fill_layoutE(problem_shape); + + return true; +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments make_args(Options const& options) +{ + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + problem_shape, + { block_A_compressed.get(), layout_A, block_B.get(), stride_B, block_E.get(), layout_E }, + { { ElementAccumulator(options.alpha), ElementAccumulator(options.beta) }, + block_C.get(), stride_C, block_D.get(), stride_D } + }; + + return arguments; +} + +typename GemmRef::Arguments make_args_ref(Options const& options) +{ + typename GemmRef::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + 
problem_shape, + { block_A.get(), stride_A, block_B.get(), stride_B }, + { { ElementAccumulator(options.alpha), ElementAccumulator(options.beta) }, + block_C.get(), stride_C, block_D_ref.get(), stride_D } + }; + + return arguments; +} + +template +void print_device_tensor(cute::Tensor const& t) +{ + // Assumes size = cosize, i.e. compact tensor + std::vector data_host(t.size()); + cutlass::device_memory::copy_to_host(data_host.data(), t.data(), t.size()); + auto t_host = cute::make_tensor(data_host.data(), t.layout()); + cute::print_tensor(t_host); +} + +bool verify(Options const& options) { + CUDA_CHECK(cudaDeviceSynchronize()); + + bool passed = cutlass::reference::device::BlockCompareEqual(block_D_ref.get(), block_D.get(), block_D.size()); + +#if 0 + if (!passed) { + auto [M, N, K, L] = problem_shape; + CompressorUtility compressor_utility(problem_shape, stride_A); + int ME = compressor_utility.get_metadata_m_physical(); + int KE = compressor_utility.get_metadata_k_physical(); + int KC = compressor_utility.get_tensorA_k_physical(); + + cute::print("A (original): "); print_device_tensor(make_tensor(block_A.get(), make_shape(M, K, L), stride_A)); + cute::print("A (compressed): "); print_device_tensor(make_tensor(block_A_compressed.get(), make_shape(M, KC, L), stride_A_compressed)); + cute::print("E (physical): "); print_device_tensor(make_tensor(block_E.get(), make_shape(ME, KE, L), stride_E)); + cute::print("E (logical): "); print_device_tensor(make_tensor(block_E.get(), upcast(layout_E))); + cute::print("B: "); print_device_tensor(make_tensor(block_B.get(), make_shape(N, K, L), stride_B)); + cute::print("C: "); print_device_tensor(make_tensor(block_C.get(), make_shape(M, N, L), stride_C)); + cute::print("D reference: "); print_device_tensor(make_tensor(block_D_ref.get(), make_shape(M, N, L), stride_D)); + cute::print("D computed: "); print_device_tensor(make_tensor(block_D.get(), make_shape(M, N, L), stride_D)); + } +#endif + + return passed; +} + +template +struct Runner +{ + using Arguments = typename Gemm::Arguments; + + Runner(Arguments args): arguments(args) { + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + workspace.reset(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + } + + void run() { + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + CUTLASS_CHECK(gemm.run()); + } + + void benchmark(Options const& options) { + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + run(); + } + timer.stop(); + + // Compute average runtime and GFLOPs. 
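+      // Note that options.gflops() uses the dense FLOP count (2*M*N*K), so the sparse and dense
+      // runs reported by this benchmark are on the same scale and directly comparable.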
+ float elapsed_ms = timer.elapsed_millis(); + double avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + double gflops = options.gflops(avg_runtime_ms / 1000.0); + + std::cout << " Avg runtime: " << avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << gflops << std::endl; + } + } + + Gemm gemm; + Arguments arguments; + cutlass::device_memory::allocation workspace; +}; + +/// Execute the example (verification and timing) +void run(Options &options) { + bool init = initialize(options); + if (!init) { + std::cout << "Initialization failure" << std::endl; + exit(EXIT_FAILURE); + } + + Runner gemm(make_args(options)); + Runner gemm_ref(make_args_ref(options)); + + gemm.run(); + gemm_ref.run(); + + bool passed = verify(options); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << std::endl; + std::cout << " Disposition: " << (passed ? "Passed" : "Failed") << std::endl; + + if (!passed) { + exit(EXIT_FAILURE); + } + + std::cout << "Sparse GEMM:" << std::endl; + gemm.benchmark(options); + + std::cout << "Dense GEMM:" << std::endl; + gemm_ref.benchmark(options); +} + +#endif // defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.2 Toolkit to run this example + // and must have compute capability at least 90. + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 2)) { + std::cerr << "This example requires CUDA 12.2 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. + return 0; + } + + cudaDeviceProp props; + int current_device_id; + CUDA_CHECK(cudaGetDevice(¤t_device_id)); + CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id)); + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (props.major < 9) { + std::cerr + << "This example requires a GPU of NVIDIA's Hopper Architecture or " + << "later (compute capability 90 or greater).\n"; + return 0; + } + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + // + // Evaluate CUTLASS kernels + // + +#if defined(CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED) + run(options); +#endif + + return EXIT_SUCCESS; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/62_hopper_sparse_gemm/CMakeLists.txt b/examples/62_hopper_sparse_gemm/CMakeLists.txt new file mode 100644 index 0000000000..cf55da4552 --- /dev/null +++ b/examples/62_hopper_sparse_gemm/CMakeLists.txt @@ -0,0 +1,36 @@ + +# Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. 
Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + +# Sparse kernel in this example triggers an ICE in gcc 7.5 +if (NOT (CMAKE_CXX_COMPILER_ID STREQUAL "GNU" AND CMAKE_CXX_COMPILER_VERSION VERSION_LESS 8.0)) +cutlass_example_add_executable( + 62_hopper_sparse_gemm + 62_hopper_sparse_gemm.cu + ) +endif() diff --git a/examples/63_hopper_gemm_with_weight_prefetch/63_hopper_gemm_with_weight_prefetch.cu b/examples/63_hopper_gemm_with_weight_prefetch/63_hopper_gemm_with_weight_prefetch.cu new file mode 100644 index 0000000000..03c54a8ee9 --- /dev/null +++ b/examples/63_hopper_gemm_with_weight_prefetch/63_hopper_gemm_with_weight_prefetch.cu @@ -0,0 +1,500 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Hopper FP8 GEMM + L2 Weight Prefetch + + This example implements a non-persistent warp-specialized GEMM kernel for the Hopper + architecture with programmatic dependent launch (PDL) enabling prefetching weights into + L2 cache. + + For more information about dependent launch refer to the CUDA programming guide: + https://docs.nvidia.com/cuda/cuda-c-programming-guide/index.html#programmatic-dependent-launch-and-synchronization + + In some cases, PDL can result in a window where a previous kernel is not actively utilizing + DRAM, and the next kernel sits idle until the previous finishes. During this window, the next + kernel can begin loading a non-dependent operand (i.e. weights in a linear projection are + typically static) and cache it in L2. + + The kernel and collective mainloop assume operand `A` corresponds to weights and operand `B` + corresponds to activations (so we can have very small batch/token count). + After initialization, the prefetch warp starts loading K tiles of `A` into an unused portion + of shared memory, and loads up to half of all K tiles that the same CTA would eventually load. + The exact number of K tiles loaded is determined by `args.mainloop.prefetch_ratio` \in + [0.0, 1.0]. Smaller values result in less prefetching, and larger values result in more. + Negative values result in a "best-effort" prefetch, meaning prefetcher will stop issuing weight + loads as soon as the activation DMA warp starts loading (as soon as it is signaled that the + previous kernel has flushed its memory.) + + The DMA warp responsible for loading `A` will also begin loading K tiles until it fills up + the available shared memory. + The DMA warp responsible for loading `B` will wait until activations are flushed to global + memory by the preceding kernel. + + Another mainloop parameter, `args.mainloop.overlap_ratio` \in [0.0, 1.0] determines how early + the next kernel (the one doing the prefetch) is launched. Smaller values result in greater + overlap, and larger values result in smaller overlap. Negative values disable PDL completely, + meaning there will be no overlap. This will make prefetch ineffective. + + These two runtime parameters should be tuned per problem size and GEMM config combination, and + if feasible, per-operation in an entire layer or model. + + NOTE: you must build this target with the following flag to enable Grid Dependency Control + instructions (GDC) in CUTLASS: + - CUTLASS_ENABLE_GDC_FOR_SM90 + + To lock persistence mode, power (350W), clocks (1005MHz) for evaluation (assumes device 0 and H100) + + $ sudo nvidia-smi -pm 1 -i 0 + + $ sudo nvidia-smi -i 0 -pl 350 + + $ sudo nvidia-smi -i 0 -lgc 1005 + + Example: + + $ mkdir build && cd build + + $ cmake .. 
-DCUTLASS_NVCC_ARCHS="90a" -DCUTLASS_ENABLE_GDC_FOR_SM90=1 + + $ cd examples/63_hopper_gemm_with_weight_prefetch + + $ make + + $ ./63_hopper_gemm_with_weight_prefetch --p=0.5 --o=0.5 +*/ + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_types.h" + +#include "cute/tensor.hpp" +#include "cutlass/tensor_ref.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/collective/collective_builder.hpp" +#include "cutlass/gemm/device/gemm_universal_adapter.h" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_builder.hpp" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/packed_stride.hpp" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/reference/host/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_norm.h" +#include "cutlass/util/reference/host/gett.hpp" + + +#include "collective/dispatch_policy_extra.hpp" +#include "collective/builder.hpp" +#include "kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp" + +#include "helper.h" +#include "gemm_with_weight_prefetch_commandline.hpp" + +using namespace cute; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM kernel configurations +///////////////////////////////////////////////////////////////////////////////////////////////// + +// A matrix configuration +using ElementA = cutlass::float_e4m3_t; // Element type for A matrix operand +using LayoutA = cutlass::layout::RowMajor; // Layout type for A matrix operand +constexpr int AlignmentA = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of A matrix in units of elements (up to 16 bytes) + +// B matrix configuration +using ElementB = cutlass::float_e5m2_t; // Element type for B matrix operand +using LayoutB = cutlass::layout::ColumnMajor; // Layout type for B matrix operand +constexpr int AlignmentB = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of B matrix in units of elements (up to 16 bytes) + +// C matrix configuration +using ElementC = cutlass::float_e4m3_t; // Element type for C and D matrix operands +using LayoutC = cutlass::layout::ColumnMajor; // Layout type for C and D matrix operands +constexpr int AlignmentC = 128 / cutlass::sizeof_bits::value; // Memory access granularity/alignment of C matrix in units of elements (up to 16 bytes) + +// D matrix configuration +using ElementD = ElementC; +using LayoutD = LayoutC; +constexpr int AlignmentD = AlignmentC; + +// Core kernel configurations +using ElementAccumulator = float; // Element type for internal accumulation +using ElementCompute = float; // Element type for epilogue computation +using ArchTag = cutlass::arch::Sm90; // Tag indicating the minimum SM that supports the intended feature +using OperatorClass = cutlass::arch::OpClassTensorOp; // Operator class tag +using TileShape = Shape<_64,_64,_128>; // Threadblock-level tile size +// Cluster_N > 1 is not supported yet. 
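+// (Hence the 1x1x1 cluster below.) The kernel schedule that follows,
+// KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA, is the variant with separate DMA
+// warps for A (weights) and B (activations), which lets weight loads and the L2 prefetch begin
+// before the B producer has passed griddepcontrol; the non-split variant,
+// KernelTmaWarpSpecializedFP8FastAccumWithPrefetch, is also available (see this example's README).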
+using ClusterShape = Shape<_1,_1,_1>; // Shape of the threadblocks in a cluster +using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA; +using EpilogueSchedule = cutlass::epilogue::TmaWarpSpecialized; +using EpilogueTileType = cutlass::epilogue::collective::EpilogueTileAuto; + +using CollectiveEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder< + ArchTag, OperatorClass, + TileShape, ClusterShape, + EpilogueTileType, + ElementAccumulator, ElementCompute, + ElementC, LayoutC, AlignmentC, + ElementD, LayoutD, AlignmentD, + EpilogueSchedule + >::CollectiveOp; + +using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder< + ArchTag, OperatorClass, + ElementA, LayoutA, AlignmentA, + ElementB, LayoutB, AlignmentB, + ElementAccumulator, + TileShape, ClusterShape, + cutlass::gemm::collective::StageCountAutoCarveout< + static_cast(sizeof(typename CollectiveEpilogue::SharedStorage)) + >, + KernelSchedule + >::CollectiveOp; + +using GemmKernel = cutlass::gemm::kernel::GemmUniversal< + Shape, // Indicates ProblemShape + CollectiveMainloop, + CollectiveEpilogue +>; + +using Gemm = cutlass::gemm::device::GemmUniversalAdapter; + +// Extract information from Gemm kernel. +using EpilogueOutputOp = typename Gemm::EpilogueOutputOp; +using ElementScalar = typename EpilogueOutputOp::ElementScalar; + +using StrideA = typename Gemm::GemmKernel::StrideA; +using StrideB = typename Gemm::GemmKernel::StrideB; +using StrideC = typename Gemm::GemmKernel::StrideC; +using StrideD = typename Gemm::GemmKernel::StrideD; + +/// Initialization +StrideA stride_A; +StrideB stride_B; +StrideC stride_C; +StrideD stride_D; +uint64_t seed; + +cutlass::HostTensor tensor_A; +cutlass::HostTensor tensor_B; +cutlass::HostTensor tensor_C; +cutlass::HostTensor tensor_D; +cutlass::HostTensor tensor_ref_D; + +using LayoutScalar = cutlass::layout::PackedVectorLayout; +cutlass::HostTensor scalar_alpha; +cutlass::HostTensor scalar_beta; + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Testbed utility types +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Result structure +struct Result +{ + double avg_runtime_ms; + double gflops; + double eff_bw; + cutlass::Status status; + cudaError_t error; + bool passed; + + Result( + double avg_runtime_ms = 0, + double gflops = 0, + double eff_bw = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess) + : + avg_runtime_ms(avg_runtime_ms), gflops(gflops), eff_bw(eff_bw), status(status), error(error), passed(false) + {} + +}; + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// GEMM setup and evaluation +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Helper to initialize a block of device data +template +bool initialize_tensor( + cutlass::TensorView view, + uint64_t seed) { + + double scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = 2; + scope_min = 0; + } + else if (bits_input <= 8) { + scope_max = 2; + scope_min = -2; + } + else if (bits_output == 16) { + scope_max = 5; + scope_min = -5; + } + else { + scope_max = 8; + scope_min = -8; + } + 
cutlass::reference::host::TensorFillRandomUniform( + view, seed, scope_max, scope_min, 0); + + return true; +} + +/// Initialize operands to be used in the GEMM and reference GEMM +void initialize(const Options &options) { + + stride_A = cutlass::make_cute_packed_stride(StrideA{}, cute::make_shape(options.m, options.k, options.l)); + stride_B = cutlass::make_cute_packed_stride(StrideB{}, cute::make_shape(options.n, options.k, options.l)); + stride_C = cutlass::make_cute_packed_stride(StrideC{}, cute::make_shape(options.m, options.n, options.l)); + stride_D = cutlass::make_cute_packed_stride(StrideD{}, cute::make_shape(options.m, options.n, options.l)); + + auto a_coord = cutlass::make_Coord(options.m * options.l, options.k); + auto c_coord = cutlass::make_Coord(options.m * options.l, options.n); + auto b_coord = cutlass::make_Coord(options.k, options.n * options.l); + + tensor_A.resize(a_coord); + tensor_B.resize(b_coord); + tensor_C.resize(c_coord); + tensor_D.resize(c_coord); + tensor_ref_D.resize(c_coord); + + initialize_tensor(tensor_A.host_view(), seed + 2022); + initialize_tensor(tensor_B.host_view(), seed + 2023); + initialize_tensor(tensor_C.host_view(), seed + 2024); + + tensor_A.sync_device(); + tensor_B.sync_device(); + tensor_C.sync_device(); + tensor_D.sync_device(); +} + +/// Populates a Gemm::Arguments structure from the given commandline options +typename Gemm::Arguments args_from_options(const Options &options) +{ + typename Gemm::Arguments arguments{ + cutlass::gemm::GemmUniversalMode::kGemm, + {options.m, options.n, options.k, options.l}, + {tensor_A.device_data(), stride_A, tensor_B.device_data(), stride_B}, + { + {}, // epilogue.thread + tensor_C.device_data(), stride_C, + tensor_D.device_data(), stride_D + } + }; + + auto &fusion_args = arguments.epilogue.thread; + fusion_args.alpha = options.alpha; + fusion_args.beta = options.beta; + fusion_args.alpha_ptr = scalar_alpha.device_data(); + fusion_args.beta_ptr = scalar_beta.device_data(); + + arguments.mainloop.overlap_ratio = options.overlap_ratio; + arguments.mainloop.prefetch_ratio = options.prefetch_ratio; + + return arguments; +} + +bool verify(const Options &options) { + // + // Compute reference output + // + + // Create instantiation for device reference gemm kernel + auto A = cute::make_tensor(tensor_A.host_data(), + cute::make_layout(cute::make_shape(options.m, options.k, options.l), stride_A)); + auto B = cute::make_tensor(tensor_B.host_data(), + cute::make_layout(cute::make_shape(options.n, options.k, options.l), stride_B)); + auto C = cute::make_tensor(tensor_C.host_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_C)); + auto D = cute::make_tensor(tensor_ref_D.host_data(), + cute::make_layout(cute::make_shape(options.m, options.n, options.l), stride_D)); + using unused_t = decltype(D); + + cutlass::reference::host::GettMainloopParams mainloop_params{A, B}; + + cutlass::reference::host::GettEpilogueParams< + ElementScalar, + ElementScalar, + ElementAccumulator, + ElementCompute, + decltype(C), + decltype(D), + unused_t, // bias + unused_t, // aux + unused_t, // valpha + unused_t // vbeta + > epilogue_params; + + epilogue_params.C = C; + epilogue_params.D = D; + epilogue_params.alpha = options.alpha; + epilogue_params.beta = options.beta; + + // get reference result + cutlass::reference::host::Gemm3x(mainloop_params, epilogue_params); + + // compare_reference + tensor_D.sync_host(); + bool passed = cutlass::reference::host::TensorEquals(tensor_ref_D.host_view(), 
tensor_D.host_view()); + + return passed; +} + +/// Execute a given example GEMM computation +template +int run(Options &options) +{ + initialize(options); + + // Instantiate CUTLASS kernel depending on templates + Gemm gemm; + + // Create a structure of gemm kernel arguments suitable for invoking an instance of Gemm + auto arguments = args_from_options(options); + + // Using the arguments, query for extra workspace required for matrix multiplication computation + size_t workspace_size = Gemm::get_workspace_size(arguments); + + // Allocate workspace memory + cutlass::device_memory::allocation workspace(workspace_size); + + // Check if the problem size is supported or not + CUTLASS_CHECK(gemm.can_implement(arguments)); + + // Initialize CUTLASS kernel with arguments and workspace pointer + CUTLASS_CHECK(gemm.initialize(arguments, workspace.get())); + + // Correctness / Warmup iteration + CUTLASS_CHECK(gemm.run(nullptr, nullptr, /* launch_with_pdl = */ options.overlap_ratio >= 0)); + + // Check if output from CUTLASS kernel and reference kernel are equal or not + Result result; + result.passed = verify(options); + + std::cout << " Disposition: " << (result.passed ? "Passed" : "Failed") << std::endl; + + if (!result.passed) { + exit(-1); + } + + // Run profiling loop + if (options.iterations > 0) + { + GpuTimer timer; + timer.start(); + for (int iter = 0; iter < options.iterations; ++iter) { + CUTLASS_CHECK(gemm.run(nullptr, nullptr, /* launch_with_pdl = */ options.overlap_ratio >= 0)); + } + timer.stop(); + + // Compute average runtime and GFLOPs. + float elapsed_ms = timer.elapsed_millis(); + result.avg_runtime_ms = double(elapsed_ms) / double(options.iterations); + double avg_runtime_s = (double)(result.avg_runtime_ms / 1000.0); + result.gflops = options.gflops(avg_runtime_s); + result.eff_bw = options.effective_bandwidth(avg_runtime_s, sizeof(ElementA), sizeof(ElementB), sizeof(ElementC), sizeof(ElementD)); + + std::cout << " Problem Size: " << options.m << 'x' << options.n << 'x' << options.k << 'x' << options.l << std::endl; + std::cout << " Avg runtime: " << result.avg_runtime_ms << " ms" << std::endl; + std::cout << " GFLOPS: " << result.gflops << std::endl; + std::cout << " Effective bandwidth: " << result.eff_bw << " GB/s" << std::endl; + } + + return 0; +} + +#endif // defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + // CUTLASS must be compiled with CUDA 12.0 Toolkit to run this example + // and must have compute capability at least 90. + if (__CUDACC_VER_MAJOR__ < 12) { + std::cerr << "This example requires CUDA 12 or newer.\n"; + // Returning zero so this test passes on older Toolkits. Its actions are no-op. 
+    return 0;
+  }
+
+  cudaDeviceProp props;
+  int current_device_id;
+  CUDA_CHECK(cudaGetDevice(&current_device_id));
+  CUDA_CHECK(cudaGetDeviceProperties(&props, current_device_id));
+  if (props.major < 9) {
+    std::cerr
+      << "This example requires a GPU of NVIDIA's Hopper Architecture or "
+      << "later (compute capability 90 or greater).\n";
+    return 0;
+  }
+  //
+  // Parse options
+  //
+
+  Options options;
+
+  options.parse(argc, args);
+
+  if (options.help) {
+    options.print_usage(std::cout) << std::endl;
+    return 0;
+  }
+
+  //
+  // Evaluate CUTLASS kernels
+  //
+
+#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED)
+  run(options);
+#endif
+
+  return 0;
+}
diff --git a/examples/63_hopper_gemm_with_weight_prefetch/CMakeLists.txt b/examples/63_hopper_gemm_with_weight_prefetch/CMakeLists.txt
new file mode 100644
index 0000000000..f48673241a
--- /dev/null
+++ b/examples/63_hopper_gemm_with_weight_prefetch/CMakeLists.txt
@@ -0,0 +1,36 @@
+# Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+# SPDX-License-Identifier: BSD-3-Clause
+#
+# Redistribution and use in source and binary forms, with or without
+# modification, are permitted provided that the following conditions are met:
+#
+# 1. Redistributions of source code must retain the above copyright notice, this
+# list of conditions and the following disclaimer.
+#
+# 2. Redistributions in binary form must reproduce the above copyright notice,
+# this list of conditions and the following disclaimer in the documentation
+# and/or other materials provided with the distribution.
+#
+# 3. Neither the name of the copyright holder nor the names of its
+# contributors may be used to endorse or promote products derived from
+# this software without specific prior written permission.
+#
+# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+
+include_directories(
+  .
+)
+
+cutlass_example_add_executable(
+  63_hopper_gemm_with_weight_prefetch
+  63_hopper_gemm_with_weight_prefetch.cu
+  )
diff --git a/examples/63_hopper_gemm_with_weight_prefetch/README.md b/examples/63_hopper_gemm_with_weight_prefetch/README.md
new file mode 100644
index 0000000000..5dac1cc6c2
--- /dev/null
+++ b/examples/63_hopper_gemm_with_weight_prefetch/README.md
@@ -0,0 +1,82 @@
+# GEMM with L2 weight prefetch
+
+A non-persistent, warp-specialized GEMM aimed at low-latency inference.
+
+The kernel can optionally prefetch a portion of the weights (operand `A`) into L2 cache while the
+rest of the warps are waiting on the previous kernel to finish writing and flush its memory.
+A typical case is a normalization or reduction kernel that is immediately followed by a GEMM.
+
+It exposes two runtime parameters:
+1. `overlap_ratio`: how early `griddepcontrol.launch_dependent_grids` is issued.
+   Default is `0.5`, meaning after approximately half of the K tiles are loaded by the DMA warps.
+2. `prefetch_ratio`: what fraction of the K tiles to prefetch.
+   Default is `-1.0`, meaning prefetching stops as soon as the other DMA warps are past
+   `griddepcontrol`.
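+
+Both ratios are ordinary fields on the mainloop arguments, so they can be adjusted per GEMM call
+without recompiling. A minimal sketch, reusing the `Options` and `args_from_options` helpers
+defined in this example's source:
+
+```cxx
+// Sketch: adjust the prefetch behavior per invocation.
+auto arguments = args_from_options(options);   // populates all other GEMM arguments as in the example
+arguments.mainloop.overlap_ratio  = 0.8f;      // launch dependent grids after ~80% of K tiles are loaded
+arguments.mainloop.prefetch_ratio = 0.7f;      // prefetch roughly 70% of the K tiles of A
+// Values must not exceed 1.0; negative values select the default behaviors described above.
+```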
+
+It is highly recommended to auto-tune these parameters per GEMM against an end-to-end runtime
+measurement (an entire transformer layer or a few of them, but probably not the entire model).
+
+TMA loads use non-default cache hints: `A` (weights) is loaded with `EvictFirst`, and `B` (activations)
+is loaded with `EvictLast`.
+
+## Getting started
+To use this kernel in your own target, add this directory to your includes, and include the
+following headers from this example:
+
+```cxx
+#include "collective/dispatch_policy_extra.hpp"
+#include "collective/builder.hpp"
+#include "kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp"
+```
+
+Then use either of the two new kernel schedules:
+
+```cxx
+// Without separate warps for A and B
+using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccumWithPrefetch;
+
+// With separate warps for A and B
+using KernelSchedule = cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA;
+```
+
+The kernel with separate warps for A and B
+(`KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA`)
+is expected to be the faster of the two, especially since it allows the kernel to load
+weights into shmem ahead of the `griddepcontrol`. The sketch below shows how a schedule plugs
+into the collective builder.
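+
+The following is an illustrative sketch of plugging one of these schedules into the CUTLASS 3.x
+collective builder. The element types, layouts, alignments, and shapes are placeholder choices
+(FP8 inputs in a TN layout, which these schedules require), and `CollectiveEpilogue` is assumed
+to have been built beforehand in the usual way:
+
+```cxx
+using ElementA     = cutlass::float_e4m3_t;                      // placeholder FP8 input types
+using ElementB     = cutlass::float_e4m3_t;
+using TileShape    = cute::Shape<cute::_128, cute::_128, cute::_128>;
+using ClusterShape = cute::Shape<cute::_1, cute::_1, cute::_1>;  // clusters larger than 1 CTA are unsupported
+
+using CollectiveMainloop = typename cutlass::gemm::collective::CollectiveBuilder<
+    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
+    ElementA, cutlass::layout::RowMajor,    16,                  // A is K-major ("T")
+    ElementB, cutlass::layout::ColumnMajor, 16,                  // B is K-major ("N")
+    float,                                                       // accumulator type
+    TileShape, ClusterShape,
+    cutlass::gemm::collective::StageCountAutoCarveout<
+        static_cast<int>(sizeof(typename CollectiveEpilogue::SharedStorage))>,
+    cutlass::gemm::KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA
+  >::CollectiveOp;
+```
+
+The resulting `CollectiveOp` is then composed with an epilogue and `kernel::GemmUniversal` as in
+other CUTLASS 3.x examples.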
+
+As for other GEMM parameters: thread block clusters larger than 1 CTA are not yet supported, and
+since the kernel layer implementation is warp specialized and TMA-based, other kernel layers or
+collectives would require reimplementation.
+
+## Example
+
+Using the example is straightforward: build it, then run with your choice of `MNK`:
+
+```bash
+./63_hopper_gemm_with_weight_prefetch --m=8192 --n=1 --k=8192
+```
+
+You can also disable the overlap, or try different overlap and prefetch ratios and see the
+difference:
+
+```bash
+echo "Without overlap and prefetch"
+./63_hopper_gemm_with_weight_prefetch --o=-1.0 --p=-1.0
+
+echo "Overlap ratio of 0.5, best effort prefetch"
+./63_hopper_gemm_with_weight_prefetch --o=0.5 --p=-1.0
+
+echo "Overlap ratio of 0.8, prefetch ratio of 0.7"
+./63_hopper_gemm_with_weight_prefetch --o=0.8 --p=0.7
+```
+
+However, note that the example still runs a single GEMM, and most of the performance improvement
+is expected in end-to-end applications.
+
+
+## Limitations
+* The parameter defaults are typically not good choices, especially `prefetch_ratio`.
+  When `prefetch_ratio` is unspecified (set to `-1.0`), the prefetch warp will `try_wait` on a
+  memory barrier before issuing every single TMA load, and in many cases this will slow down
+  prefetching to the point of being almost ineffective.
diff --git a/examples/63_hopper_gemm_with_weight_prefetch/collective/builder.hpp b/examples/63_hopper_gemm_with_weight_prefetch/collective/builder.hpp
new file mode 100644
index 0000000000..bfb64820f0
--- /dev/null
+++ b/examples/63_hopper_gemm_with_weight_prefetch/collective/builder.hpp
@@ -0,0 +1,242 @@
+/***************************************************************************************************
+ * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+#pragma once
+
+#include "cutlass/gemm/collective/collective_builder.hpp"
+
+#include "dispatch_policy_extra.hpp"
+#include "sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp"
+#include "../pipeline/prefetch_pipeline_sm90.hpp"
+
+namespace cutlass::gemm::collective {
+
+namespace detail {
+
+// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
+template
+constexpr int
+compute_stage_count_or_override_prefetch(StageCount stage_count) {
+  return stages;
+}
+
+// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count.
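+// Note on the overload below: the usable capacity is CapacityBytes minus the epilogue carveout
+// (carveout_bytes), minus the A-tile buffer(s) reserved for the prefetcher
+// (MK_bytes * PrefetchStagesActual), minus the prefetcher pipeline storage; the remainder is
+// divided by the per-stage footprint (one A tile + one B tile + mainloop barrier storage) to
+// obtain the mainloop stage count.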
+template +constexpr int +compute_stage_count_or_override_prefetch(StageCountAutoCarveout stage_count) { + constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage); + constexpr auto prefetch_pipeline_bytes = sizeof(typename cutlass::detail::PrefetcherPipelineSharedStorage); + constexpr auto a_bits = cute::sizeof_bits_v; + constexpr auto b_bits = cute::sizeof_bits_v; + constexpr int MK_bytes = cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})); //also the prefetch smem size + constexpr int NK_bytes = cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})); + constexpr int stage_bytes = MK_bytes + NK_bytes + static_cast(mainloop_pipeline_bytes); + + return (CapacityBytes - carveout_bytes - MK_bytes * PrefetchStagesActual - prefetch_pipeline_bytes) / stage_bytes; +} + +} // namespace detail + +// GMMA_TMA_WS_FP8_FAST_ACCUM_SS + prefetch +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + cute::is_same_v> +> { + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(detail::is_aligned(), + "Not meet TMA alignment requirement yet\n"); + static_assert(detail::is_input_fp8(), + "Only FP8 datatypes are compatible with these kernel schedules\n"); + // Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder + static_assert(!detail::is_use_rmem_A(), + "Not supported for fp8 non-TN warp specialized kernels yet\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + using AtomLayoutMNK = Layout>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< + ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector< + GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector< + GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override_prefetch(StageCountType{}); + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + TagToStrideA_t, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + 
SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; +}; + +// GMMA_TMA_WS_FP8_FAST_ACCUM_SS + prefetch and split DMA warps +template < + class ElementA, + class GmemLayoutATag, + int AlignmentA, + class ElementB, + class GmemLayoutBTag, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ElementA, + GmemLayoutATag, + AlignmentA, + ElementB, + GmemLayoutBTag, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t< + cute::is_same_v> +> { + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(detail::is_aligned(), + "Not meet TMA alignment requirement yet\n"); + static_assert(detail::is_input_fp8(), + "Only FP8 datatypes are compatible with these kernel schedules\n"); + // Dispatch TN fp8 kernels only to TMA warp specialized FP8 builder + static_assert(!detail::is_use_rmem_A(), + "Not supported for fp8 non-TN warp specialized kernels yet\n"); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + + static constexpr cute::GMMA::Major GmmaMajorA = detail::gmma_ss_tag_to_major_A(); + static constexpr cute::GMMA::Major GmmaMajorB = detail::gmma_ss_tag_to_major_B(); + + using AtomLayoutMNK = Layout>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< + ElementA, ElementB, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); + + using GmemTiledCopyA = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<1>(ClusterShape_MNK{}))); + using GmemTiledCopyB = decltype(detail::sm90_cluster_shape_to_tma_atom(shape<0>(ClusterShape_MNK{}))); + + using SmemLayoutAtomA = decltype(detail::ss_smem_selector< + GmmaMajorA, ElementA, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(detail::ss_smem_selector< + GmmaMajorB, ElementB, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override_prefetch(StageCountType{}); + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch; + + using SmemCopyAtomA = void; + using SmemCopyAtomB = void; + + using CollectiveOp = CollectiveMma< + DispatchPolicy, + TileShape_MNK, + ElementA, + TagToStrideA_t, + ElementB, + TagToStrideB_t, + TiledMma, + GmemTiledCopyA, + SmemLayoutAtomA, + SmemCopyAtomA, + cute::identity, + GmemTiledCopyB, + SmemLayoutAtomB, + SmemCopyAtomB, + cute::identity + >; +}; + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/63_hopper_gemm_with_weight_prefetch/collective/dispatch_policy_extra.hpp b/examples/63_hopper_gemm_with_weight_prefetch/collective/dispatch_policy_extra.hpp new file mode 100644 index 0000000000..37369176f9 --- /dev/null +++ b/examples/63_hopper_gemm_with_weight_prefetch/collective/dispatch_policy_extra.hpp @@ -0,0 +1,61 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +namespace cutlass::gemm { + +// Standard non-persistent kernel with a single producer warp, and one prefetch warp. +// `A` is assumed to be static, and therefore the producer warp for `A` attempts to load `A` +// while the producer warp is waiting on griddepcontrol. +// GDC `launch_dependent_grids` is issued from the producer warp instead of math warps, and +// according to prefetch ratio. +struct KernelTmaWarpSpecializedFP8FastAccumWithPrefetch { }; + +// Non-persistent kernel with two producer warps (one for each of A and B), and one prefetch warp. +// `A` is assumed to be static, and therefore the producer warp for `A` attempts to load `A` +// while the producer warp for `B` is waiting on griddepcontrol. Producer warp for `A` does not +// wait on griddepcontrol and loads immediately. +struct KernelTmaWarpSpecializedFP8FastAccumWithPrefetchAndSplitDMA { }; + +template< + int Stages_, + class ClusterShape_ = Shape<_1,_1,_1>, + class KernelSchedule = KernelTmaWarpSpecializedFP8FastAccumWithPrefetch +> +struct MainloopSm90TmaGmmaWarpSpecializedWithPrefetch { + constexpr static int Stages = Stages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm90; + using Schedule = KernelSchedule; +}; + +} // namespace cutlass::gemm diff --git a/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp b/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp new file mode 100644 index 0000000000..9bcb1f5a7e --- /dev/null +++ b/examples/63_hopper_gemm_with_weight_prefetch/collective/sm90_mma_tma_gmma_ss_warpspecialized_with_prefetch.hpp @@ -0,0 +1,872 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/numeric_types.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" + +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/algorithm/gemm.hpp" +#include "cute/tensor_predicate.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cutlass/arch/grid_dependency_control.h" + +#include "dispatch_policy_extra.hpp" + +#include "../pipeline/prefetch_pipeline_sm90.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +constexpr int PrefetchStages = 4; +constexpr int PrefetchInitialStages = 1; +// This determines how much shmem we set aside for prefetch. +// We don't reuse anything loaded by prefetcher, so we can keep +// loading into the same place -- there will be a conflict when +// writing, but it doesn't affect performance as much as the doors +// that this opens. 
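+// Note: PrefetchStages above is the number of prefetch transactions the prefetcher pipeline's
+// barriers track in flight, whereas PrefetchStagesActual below is the number of smem tile
+// buffers actually reserved for prefetch: a single buffer that every prefetch TMA load writes
+// into, since the prefetched data is never read back from smem and only serves to warm L2.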
+constexpr int PrefetchStagesActual = 1; + +} // namespace detail + +// WarpSpecialized Mainloop +template < + int Stages, + class ClusterShape, + class KernelSchedule, + class TileShape_, + class ElementA_, + class StrideA_, + class ElementB_, + class StrideB_, + class TiledMma_, + class GmemTiledCopyA_, + class SmemLayoutAtomA_, + class SmemCopyAtomA_, + class TransformA_, + class GmemTiledCopyB_, + class SmemLayoutAtomB_, + class SmemCopyAtomB_, + class TransformB_> +struct CollectiveMma< + MainloopSm90TmaGmmaWarpSpecializedWithPrefetch, + TileShape_, + ElementA_, + StrideA_, + ElementB_, + StrideB_, + TiledMma_, + GmemTiledCopyA_, + SmemLayoutAtomA_, + SmemCopyAtomA_, + TransformA_, + GmemTiledCopyB_, + SmemLayoutAtomB_, + SmemCopyAtomB_, + TransformB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedWithPrefetch; + using TileShape = TileShape_; + using ElementA = ElementA_; + using StrideA = StrideA_; + using ElementB = ElementB_; + using StrideB = StrideB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = GmemTiledCopyA_; + using GmemTiledCopyB = GmemTiledCopyB_; + using SmemLayoutAtomA = SmemLayoutAtomA_; + using SmemLayoutAtomB = SmemLayoutAtomB_; + using SmemCopyAtomA = SmemCopyAtomA_; + using SmemCopyAtomB = SmemCopyAtomB_; + using TransformA = TransformA_; + using TransformB = TransformB_; + using ArchTag = typename DispatchPolicy::ArchTag; + + static_assert(size<1>(ClusterShape{}) == 1, "Cluster shape N must be 1"); + using CtaShape_MNK = decltype(shape_div(TileShape{}, ClusterShape{})); + + using PrefetcherPipeline = cutlass::PrefetchPipeline; + + using MainloopPipeline = cutlass::PipelineTmaAsync; + using PipelineState = cutlass::PipelineState; + using PipelineParams = typename MainloopPipeline::Params; + + static_assert(cute::rank(SmemLayoutAtomA{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<0>(TileShape{}) % size<0>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomA{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + static_assert(cute::rank(SmemLayoutAtomB{}) == 2, "SmemLayoutAtom must be rank 2 (M/N, K)"); + static_assert((size<1>(TileShape{}) % size<0>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + static_assert((size<2>(TileShape{}) % size<1>(SmemLayoutAtomB{})) == 0, "SmemLayoutAtom must evenly divide tile shape."); + + // Tile along modes in a way that maximizes the TMA box size. 
+ using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideA>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{}), Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,StrideB>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{})); + + static_assert(rank(SmemLayoutA{}) == 3 && size<2>(SmemLayoutA{}) == DispatchPolicy::Stages); + static_assert(rank(SmemLayoutB{}) == 3 && size<2>(SmemLayoutB{}) == DispatchPolicy::Stages); + + using PrefetchSmemLayoutA = decltype(make_layout(make_shape( + cute::Int(SmemLayoutA{})>{}, + cute::Int(SmemLayoutA{})>{}, + cute::Int{}))); + + static constexpr auto prefetch_smem_size = cute::cosize_v; + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 2 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + static_assert(cute::is_same_v || cute::is_same_v, + "GmemTiledCopy - invalid SM90 TMA copy atom specified."); + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. + static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using InternalElementA = cute::conditional_t>>; + using InternalElementB = cute::conditional_t>>; + + // Defined outside the class where it's used, to work around MSVC issues + using PrefetcherPipelineStorage = ::cutlass::detail::PrefetcherPipelineSharedStorage; + + struct SharedStorage { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + cute::array_aligned smem_prefetch; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + PrefetcherPipelineStorage prefetcher_pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A; + StrideA dA; + ElementB const* ptr_B; + StrideB dB; + uint32_t mma_promotion_interval = 4; + float overlap_ratio = 0.5; + float prefetch_ratio = -1.0; + }; + + // Device side kernel params + struct Params { + // Assumption: StrideA is congruent with Problem_MK + using TMA_A = decltype(make_tma_copy_A_sm90( + GmemTiledCopyA{}, + make_tensor(static_cast(nullptr), repeat_like(StrideA{}, int32_t(0)), StrideA{}), + SmemLayoutA{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{})); + // Assumption: StrideB is congruent with Problem_NK + using TMA_B = decltype(make_tma_copy_B_sm90( + GmemTiledCopyB{}, + make_tensor(static_cast(nullptr), repeat_like(StrideB{}, int32_t(0)), StrideB{}), + SmemLayoutB{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{})); + + TMA_A tma_load_a; + TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytesMK + TmaTransactionBytesNK; + uint32_t tma_transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t tma_transaction_bytes_nk = TmaTransactionBytesNK; + float overlap_ratio = 0.5; + float prefetch_ratio = -1.0; + }; + + // + // 
Methods + // + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + + // Optionally append 1s until problem shape is rank-4 (MNKL), in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); + + Tensor tensor_a = make_tensor(ptr_A, make_layout(make_shape(M,K,L), args.dA)); + Tensor tensor_b = make_tensor(ptr_B, make_layout(make_shape(N,K,L), args.dB)); + + typename Params::TMA_A tma_load_a = make_tma_copy_A_sm90( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); + typename Params::TMA_B tma_load_b = make_tma_copy_B_sm90( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,cute::Int<0>{}), + TileShape{}, + ClusterShape{}); + uint32_t transaction_bytes_mk = TmaTransactionBytesMK; + uint32_t transaction_bytes_nk = TmaTransactionBytesNK; + uint32_t transaction_bytes = transaction_bytes_mk + transaction_bytes_nk; + + return { + tma_load_a, + tma_load_b, + transaction_bytes, + transaction_bytes_mk, + transaction_bytes_nk, + args.overlap_ratio, + args.prefetch_ratio + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + constexpr int tma_alignment_bits = 128; + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + bool implementable = cutlass::detail::check_alignment(cute::make_shape(M,K,L), StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(N,K,L), StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + return false; + } + + if (args.overlap_ratio > 1.0) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: `overlap_ratio` must be either negative (disabled) or in [0, 1].\n"); + return false; + } + + if (args.prefetch_ratio > 1.0) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: `prefetch_ratio` must be either negative (disabled) or in [0, 1].\n"); + return false; + } + + return true; + } + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = 1; + static constexpr uint32_t TmaTransactionBytesMK = + cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof_bits::value)); + static constexpr uint32_t TmaTransactionBytesNK = + cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof_bits::value)); + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. 
The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mkl - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k,l) + /// gB_nkl - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k,l) + /// The rest of the tensors can be specified as needed by this collective. + template + CUTLASS_DEVICE auto + load_init(ProblemShape_MNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M,N,K,L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mkl = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K,L)); // (m,k,l) + Tensor mB_nkl = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K,L)); // (n,k,l) + + // Make tiled views, defer the slice + Tensor gA_mkl = local_tile(mA_mkl, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k,l) + Tensor gB_nkl = local_tile(mB_nkl, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k,l) + + return cute::make_tuple(gA_mkl, gB_nkl); + } + + template < + class TensorA, class TensorB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PrefetcherPipeline prefetcher_pipeline, + PipelineState smem_pipe_write, + TensorA const& gA_mkl, + TensorB const& gB_nkl, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + bool disable_gdc = mainloop_params.overlap_ratio < 0.0; + float overlap_ratio = mainloop_params.overlap_ratio; + int launch_dep_grids_threshold = static_cast(static_cast(k_tile_count - 1) * overlap_ratio); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + auto cta_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto cta_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. 
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from cta_tma_a + Tensor tAgA = cta_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = cta_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + // Applies the mapping from cta_tma_b + Tensor tBgB = cta_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = cta_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // We have to wait on dependent grids because of B. + cutlass::arch::wait_on_dependent_grids(); + + // Signal prefetcher to stop + prefetcher_pipeline.producer_arrive(); + + bool launch_dep_grids = false; + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (int cnt=0 ; k_tile_count > 0; --k_tile_count, ++cnt) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b, cute::TMA::CacheHintSm90::EVICT_LAST), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) { + launch_dep_grids = true; + cutlass::arch::launch_dependent_grids(); + } + + // Advance smem_pipe_write + ++smem_pipe_write; + } + if (!disable_gdc && !launch_dep_grids) { + cutlass::arch::launch_dependent_grids(); + } + } + } + + template < + class TensorA, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load_MK( + Params const& mainloop_params, + MainloopPipeline pipeline, + PrefetcherPipeline prefetcher_pipeline, + PipelineState smem_pipe_write, + TensorA const& gA_mkl, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + bool disable_gdc = mainloop_params.overlap_ratio < 0.0; + float overlap_ratio = mainloop_params.overlap_ratio; + int launch_dep_grids_threshold = static_cast(static_cast(k_tile_count - 1) * overlap_ratio); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + + // + // Prepare the TMA loads for A + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + auto cta_tma_a = 
mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + + // Applies the mapping from cta_tma_a + Tensor tAgA = cta_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = cta_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + // Don't wait on dependent grids when loading `A`, because + // we assume `A` (weights) are static. + + bool launch_dep_grids = false; + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (int cnt=0 ; k_tile_count > 0; --k_tile_count, ++cnt) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + ++k_tile_iter; + + if (!disable_gdc && cnt >= launch_dep_grids_threshold && !launch_dep_grids) { + launch_dep_grids = true; + cutlass::arch::launch_dependent_grids(); + } + + // Advance smem_pipe_write + ++smem_pipe_write; + } + if (!disable_gdc && !launch_dep_grids) { + cutlass::arch::launch_dependent_grids(); + } + } + } + + template < + class TensorB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load_NK( + Params const& mainloop_params, + MainloopPipeline pipeline, + PrefetcherPipeline prefetcher_pipeline, + PipelineState smem_pipe_write, + TensorB const& gB_nkl, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for B + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + auto cta_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + // Partition the inputs based on the current block coordinates. 
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gB = gB_nkl(_,_,n_coord,_,l_coord); // (BLK_N,BLK_K,k) + + // Applies the mapping from cta_tma_b + Tensor tBgB = cta_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = cta_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // Ensure that the prefetched kernel does not touch + // unflushed global memory prior to this instruction + cutlass::arch::wait_on_dependent_grids(); + + // Signal prefetcher to stop + prefetcher_pipeline.producer_arrive(); + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for (; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_write for _writing_ + pipeline.producer_acquire(smem_pipe_write); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_write); + + int write_stage = smem_pipe_write.index(); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b, cute::TMA::CacheHintSm90::EVICT_LAST), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_write + ++smem_pipe_write; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_write) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_write); + } + } + + + template < + class TensorA, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + prefetch_MK( + Params const& mainloop_params, + PrefetcherPipeline prefetcher_pipeline, + PipelineState smem_pipe_write, + TensorA const& gA_mkl, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + int lane_predicate = cute::elect_one_sync(); + + if (lane_predicate) { + bool do_best_effort_prefetch = mainloop_params.prefetch_ratio < 0; + float prefetch_ratio = do_best_effort_prefetch ? 1.0 : mainloop_params.prefetch_ratio; + int prefetch_iters = static_cast(static_cast(k_tile_count) * 0.5 * prefetch_ratio); + prefetch_iters = min(k_tile_count, ((prefetch_iters + detail::PrefetchStages - 1) / detail::PrefetchStages) * detail::PrefetchStages); + + Tensor sA = make_tensor( + make_smem_ptr(shared_tensors.smem_prefetch.data()), PrefetchSmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + + // + // Prepare the TMA loads for A + // + + constexpr uint32_t cluster_shape_x = get<0>(typename DispatchPolicy::ClusterShape()); + uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + + auto cta_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + + // Partition the inputs based on the current block coordinates. 
+ auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + Tensor gA = gA_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + + // Applies the mapping from cta_tma_a + Tensor tAgA = cta_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = cta_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + uint32_t prefetcher_stage = 0; + uint32_t prefetcher_phase = 0; + CUTLASS_PRAGMA_NO_UNROLL + for (int cnt = 0 ; cnt < prefetch_iters; ++cnt) { + + if (do_best_effort_prefetch && prefetcher_pipeline.have_producers_arrived()) { + break; + } + + prefetcher_pipeline.prefetcher_acquire(prefetcher_stage, prefetcher_phase, cnt >= detail::PrefetchStages); + using BarrierType = typename PrefetcherPipeline::PrefetcherBarrierType; + BarrierType* tma_barrier = prefetcher_pipeline.prefetcher_get_barrier(prefetcher_stage); + + int write_stage = 0; + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a, cute::TMA::CacheHintSm90::EVICT_FIRST), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + ++k_tile_iter; + ++k_tile_iter; + + prefetcher_pipeline.advance_prefetcher_state(prefetcher_stage, prefetcher_phase); + } + prefetcher_pipeline.prefetcher_tail(prefetcher_stage, prefetcher_phase); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template < + class FrgTensorC + > + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_read, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + static_assert(is_rmem::value, "C tensor must be rmem resident."); + static_assert(cute::rank(SmemLayoutA{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::rank(SmemLayoutB{}) == 3, "Smem layout must be rank 3."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + static_assert(cute::is_void_v, + "SM90 GMMA mainloops cannot have a non-void copy atom for smem sourced instructions."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + 
+ // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_read; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + int read_stage = smem_pipe_read.index(); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + + ++smem_pipe_read; + } + + warpgroup_fence_operand(accum); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) + { + // WAIT on smem_pipe_read until its data are available (phase bit flips from rdPhaseBit value) + auto barrier_token = pipeline.consumer_try_wait(smem_pipe_read); + pipeline.consumer_wait(smem_pipe_read, barrier_token); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_read.index(); + warpgroup_fence_operand(accum); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_write is consumed + warpgroup_wait(); + warpgroup_fence_operand(accum); + + // UNLOCK smem_pipe_release, done _computing_ on it + pipeline.consumer_release(smem_pipe_release); + + // Advance smem_pipe_read and smem_pipe_release + ++smem_pipe_read; + ++smem_pipe_release; + } + + warpgroup_fence_operand(accum); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; + + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/63_hopper_gemm_with_weight_prefetch/gemm_with_weight_prefetch_commandline.hpp b/examples/63_hopper_gemm_with_weight_prefetch/gemm_with_weight_prefetch_commandline.hpp new file mode 100644 index 0000000000..6be87768ee --- /dev/null +++ b/examples/63_hopper_gemm_with_weight_prefetch/gemm_with_weight_prefetch_commandline.hpp @@ -0,0 +1,117 @@ 
+/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +// Command line options parsing +struct Options { + + bool help = false; + + float alpha = 1.f, beta = 0.f; + float overlap_ratio = 0.5f, prefetch_ratio = 0.5f; + int iterations = 1000; + int n = 64, m = 1280, k = 8192, l = 1; + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("m", m); + cmd.get_cmd_line_argument("n", n); + cmd.get_cmd_line_argument("k", k); + cmd.get_cmd_line_argument("l", l); + cmd.get_cmd_line_argument("alpha", alpha, 1.f); + cmd.get_cmd_line_argument("beta", beta, 0.f); + cmd.get_cmd_line_argument("p", prefetch_ratio, 0.5f); + cmd.get_cmd_line_argument("o", overlap_ratio, 0.5f); + cmd.get_cmd_line_argument("iterations", iterations); + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "63_hopper_gemm_with_weight_prefetch\n\n" + << " Hopper FP8 GEMM using a non-persistent kernel with L2 weight prefetch. 
\n" + << " For more details please refer to the source file.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement\n\n" + << " --m= Sets the M extent of the GEMM\n" + << " --n= Sets the N extent of the GEMM\n" + << " --k= Sets the K extent of the GEMM\n" + << " --l= Sets the l extent (batch) of the GEMM\n" + << " --alpha= Epilogue scalar alpha\n" + << " --beta= Epilogue scalar beta\n" + << " --p= Prefetch ratio\n" + << " --o= Overlap ratio\n" + << " --iterations= Number of profiling iterations to perform.\n\n"; + + out + << "\n\nExamples:\n\n" + << "$ " << "63_hopper_gemm_with_weight_prefetch" << + " --m=1024 --n=512 --k=1024 --o=0.5 --p=0.5 \n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const + { + // Two flops per multiply-add + uint64_t flop = uint64_t(2) * m * n * k * l; + double gflop = double(flop) / double(1.0e9); + return gflop / runtime_s; + } + + /// Compute effective bandwidth in GB/sec + double effective_bandwidth( + double runtime_s, + size_t bytes_a, + size_t bytes_b, + size_t bytes_c, + size_t bytes_d + ) const + { + static double const kBytesPerGiB = double(1ull << 30); + + double bytes_in = + (double)(l) * (double)(m) * (double)(k) * (double)(bytes_a) + // A + (double)(l) * (double)(n) * (double)(k) * (double)(bytes_b) + // B + (beta != 0.f ? (double)(l) * (double)(m) * (double)(n) * (double)(bytes_c) : 0.f); // C + double bytes_out = (double)(l) * (double)(m) * (double)(n) * (double)(bytes_d); // D + + double gb_total = (bytes_in + bytes_out) / kBytesPerGiB; + return gb_total / runtime_s; + } +}; diff --git a/examples/63_hopper_gemm_with_weight_prefetch/kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp b/examples/63_hopper_gemm_with_weight_prefetch/kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp new file mode 100644 index 0000000000..6e33d8fc62 --- /dev/null +++ b/examples/63_hopper_gemm_with_weight_prefetch/kernel/sm90_gemm_tma_warpspecialized_with_prefetch.hpp @@ -0,0 +1,561 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/reg_reconfig.h" +#include "cutlass/arch/mma_sm90.h" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/gemm/kernel/sm90_tile_scheduler.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/trace.h" + +#include "cute/tensor.hpp" + +#include "../collective/dispatch_policy_extra.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +// GEMM + Prefetch for the A tensor + (optional) split DMA warps +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileScheduler_ +> +class GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_, + cute::enable_if_t< + cute::is_same_v || + cute::is_same_v + > +> +{ +public: + // + // Type Aliases + // + using ProblemShape = ProblemShape_; + static_assert(cute::rank(ProblemShape{}) == 3 or cute::rank(ProblemShape{}) == 4, + "ProblemShape{} should be or "); + static constexpr bool IsGdcEnabled = cutlass::arch::IsGdcGloballyEnabled; + + static constexpr bool SplitWarps = cute::is_same_v; + + // Mainloop derived types + using CollectiveMainloop = CollectiveMainloop_; + using TileShape = typename CollectiveMainloop::TileShape; + using TiledMma = typename CollectiveMainloop::TiledMma; + using ArchTag = typename CollectiveMainloop::ArchTag; + using ElementA = typename CollectiveMainloop::ElementA; + using StrideA = typename CollectiveMainloop::StrideA; + using ElementB = typename CollectiveMainloop::ElementB; + using StrideB = typename CollectiveMainloop::StrideB; + using DispatchPolicy = typename CollectiveMainloop::DispatchPolicy; + using ElementAccumulator = typename CollectiveMainloop::ElementAccumulator; + using ClusterShape = typename DispatchPolicy::ClusterShape; + using MainloopArguments = typename CollectiveMainloop::Arguments; + using MainloopParams = typename CollectiveMainloop::Params; + static_assert(ArchTag::kMinComputeCapability >= 90); + + // Epilogue derived types + using CollectiveEpilogue = CollectiveEpilogue_; + using ElementC = typename CollectiveEpilogue::ElementC; + using StrideC = typename CollectiveEpilogue::StrideC; + using ElementD = typename CollectiveEpilogue::ElementD; + using StrideD = typename CollectiveEpilogue::StrideD; + using EpilogueArguments = typename CollectiveEpilogue::Arguments; + using EpilogueParams = typename CollectiveEpilogue::Params; + + static_assert(cute::is_void_v or cute::is_same_v, + "TMA warp-specialized kernel does not support specializing the tile 
scheduler."); + using TileSchedulerTag = TileScheduler_; + using TileScheduler = typename detail::TileSchedulerSelector< + TileScheduler_, ArchTag, TileShape, ClusterShape>::Scheduler; + using TileSchedulerArguments = typename TileScheduler::Arguments; + + // Kernel level shared memory storage + struct SharedStorage { + // Mainloop and epilogue don't use smem concurrently since kernel is non-persistent, so we can use a union + union TensorStorage { + using MainloopTensorStorage = typename CollectiveMainloop::TensorStorage; + using EpilogueTensorStorage = typename CollectiveEpilogue::TensorStorage; + + MainloopTensorStorage mainloop; + EpilogueTensorStorage epilogue; + } tensors; + + struct PipelineStorage : cute::aligned_struct<16, _1> { + using MainloopPipelineStorage = typename CollectiveMainloop::PipelineStorage; + using PrefetcherPipelineStorage = typename CollectiveMainloop::PrefetcherPipelineStorage; + using EpiLoadPipelineStorage = typename CollectiveEpilogue::PipelineStorage; + + alignas(16) MainloopPipelineStorage mainloop; + alignas(16) EpiLoadPipelineStorage epi_load; + alignas(16) PrefetcherPipelineStorage prefetcher; + } pipelines; + }; + + static constexpr int SharedStorageSize = sizeof(SharedStorage); + + static constexpr uint32_t NumLoadWarpGroups = 1; + static constexpr uint32_t NumMmaWarpGroups = 1; + static constexpr uint32_t MaxThreadsPerBlock = CUTE_STATIC_V(size(TiledMma{})) + (NumLoadWarpGroups * NumThreadsPerWarpGroup); + static constexpr uint32_t MinBlocksPerMultiprocessor = 1; + + // Device side arguments + struct Arguments { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopArguments mainloop{}; + EpilogueArguments epilogue{}; + KernelHardwareInfo hw_info{}; + TileSchedulerArguments scheduler{}; + }; + + // Kernel entry point API + struct Params { + GemmUniversalMode mode{}; + ProblemShape problem_shape{}; + MainloopParams mainloop{}; + EpilogueParams epilogue{}; + }; + + // + // Methods + // + + // Convert to underlying arguments. In this case, a simple copy for the aliased type. 
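+  // If the mainloop implements an A/B swap (see the Has_SwapAB_v check below), the M and N
+  // extents of the kernel-level problem shape are exchanged here, while the collectives still
+  // receive the original, un-swapped problem shape when building their own Params.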
+ static + Params + to_underlying_arguments(Arguments const& args, void* workspace) { + (void) workspace; + auto problem_shape = args.problem_shape; + if constexpr (detail::Has_SwapAB_v) { + // swap M/N + get<0>(problem_shape) = get<1>(args.problem_shape); + get<1>(problem_shape) = get<0>(args.problem_shape); + } + return { + args.mode, + problem_shape, + CollectiveMainloop::to_underlying_arguments(args.problem_shape, args.mainloop, workspace), + CollectiveEpilogue::to_underlying_arguments(args.problem_shape, args.epilogue, workspace) + }; + } + + static bool + can_implement(Arguments const& args) { + bool implementable = (args.mode == GemmUniversalMode::kGemm) or + (args.mode == GemmUniversalMode::kBatched && cute::rank(ProblemShape{}) == 4); + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Arguments or Problem Shape don't meet the requirements.\n"); + return implementable; + } + implementable &= CollectiveMainloop::can_implement(args.problem_shape, args.mainloop); + implementable &= CollectiveEpilogue::can_implement(args.problem_shape, args.epilogue); + implementable &= TileScheduler::can_implement(args.scheduler); + + return implementable; + } + + static + size_t + get_workspace_size(Arguments const& args) { + return 0; + } + + static + cutlass::Status + initialize_workspace(Arguments const& args, void* workspace = nullptr, cudaStream_t stream = nullptr, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + // Computes the kernel launch grid shape based on runtime parameters + static dim3 + get_grid_shape(Params const& params) { + auto cluster_shape = ClusterShape{}; + auto tile_shape = TileShape{}; + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + return TileScheduler::get_tiled_cta_shape_mnl( + problem_shape_MNKL, tile_shape, cluster_shape); + } + + static dim3 + get_block_shape() { + return dim3(MaxThreadsPerBlock, 1, 1); + } + + CUTLASS_DEVICE + void + operator()(Params const& params, char* smem_buf) { + using namespace cute; + using X = Underscore; + +#if defined(__CUDA_ARCH_FEAT_SM90_ALL) +# define ENABLE_SM90_KERNEL_LEVEL 1 +#endif + +// Any Tensor Op MMA Atom in the WGMMA ISA is arch conditional to sm90a. +#if ! defined(ENABLE_SM90_KERNEL_LEVEL) + printf("ERROR : Arch conditional MMA instruction used without targeting sm90a compute capability. Aborting.\n"); +#else + + enum class WarpGroupRole { + Producer = 0, + Consumer = 1, + }; + // Split mode: use Warp0 to load NK and epilogue, Warp2 to load MK. + // Non-split mode: use Warp0 to load MK, NK and epilogue, Warp2 is unused. + // Both modes use Warp1 to prefetch. 
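+    // Quick reference for the producer warp group (indices match the enum below):
+    //   warp 0 (Warp0)      : TMA loads of NK (split mode) or MK and NK (non-split), plus epilogue loads
+    //   warp 1 (PrefetchMK) : L2 prefetch of the MK (A / weight) tensor
+    //   warp 2 (Warp2)      : TMA loads of MK (split mode only, otherwise unused)
+    //   warp 3 (UnusedWarp) : unused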
+ enum class ProducerWarpRole { + Warp0 = 0, + PrefetchMK = 1, + Warp2 = 2, + UnusedWarp = 3 + }; + + // Kernel level shared memory storage + SharedStorage& shared_storage = *reinterpret_cast(smem_buf); + + int thread_idx = int(threadIdx.x); + int lane_idx = canonical_lane_idx(); + int warp_idx = canonical_warp_idx_sync(); + int warp_idx_in_warp_group = warp_idx % NumWarpsPerWarpGroup; + int warp_group_thread_idx = thread_idx % NumThreadsPerWarpGroup; + auto warp_group_role = WarpGroupRole(canonical_warp_group_idx()); + auto producer_warp_role = ProducerWarpRole(warp_idx_in_warp_group); + int lane_predicate = cute::elect_one_sync(); + uint32_t block_rank_in_cluster = cute::block_rank_in_cluster(); + + + // Issue Tma Descriptor Prefetch from a single thread + if ((warp_idx == 0) && lane_predicate) { + CollectiveMainloop::prefetch_tma_descriptors(params.mainloop); + CollectiveEpilogue::prefetch_tma_descriptors(params.epilogue); + } + + // Mainloop Load pipeline + using MainloopPipeline = typename CollectiveMainloop::MainloopPipeline; + typename MainloopPipeline::Params mainloop_pipeline_params; + mainloop_pipeline_params.is_leader = warp_group_thread_idx == 0; + if (warp_group_role == WarpGroupRole::Producer && ( + producer_warp_role == ProducerWarpRole::Warp0 || + producer_warp_role == ProducerWarpRole::Warp2)) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Producer; + mainloop_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes; + } + if (warp_group_role == WarpGroupRole::Consumer) { + mainloop_pipeline_params.role = MainloopPipeline::ThreadCategory::Consumer; + } + mainloop_pipeline_params.num_consumers = NumThreadsPerWarpGroup; + MainloopPipeline mainloop_pipeline(shared_storage.pipelines.mainloop, mainloop_pipeline_params, ClusterShape{}); + bool should_prefetch = params.mainloop.prefetch_ratio > 0; + using PrefetcherPipeline = typename CollectiveMainloop::PrefetcherPipeline; + typename PrefetcherPipeline::Params prefetcher_pipeline_params; + prefetcher_pipeline_params.num_prefetchers = 1; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::PrefetchMK) { + prefetcher_pipeline_params.should_prefetch = should_prefetch; + prefetcher_pipeline_params.transaction_bytes = params.mainloop.tma_transaction_bytes_mk; + } + PrefetcherPipeline prefetcher_pipeline(shared_storage.pipelines.prefetcher, prefetcher_pipeline_params); + + // Epilogue Load pipeline + using EpiLoadPipeline = typename CollectiveEpilogue::LoadPipeline; + typename EpiLoadPipeline::Params epi_load_pipeline_params; + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::Warp0) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Producer; + } + if (warp_group_role == WarpGroupRole::Consumer) { + epi_load_pipeline_params.role = EpiLoadPipeline::ThreadCategory::Consumer; + } + epi_load_pipeline_params.dst_blockid = cute::block_rank_in_cluster(); + epi_load_pipeline_params.producer_arv_count = NumThreadsPerWarp; + epi_load_pipeline_params.consumer_arv_count = NumThreadsPerWarpGroup; + if constexpr (CollectiveEpilogue::RequiresTransactionBytes) { + epi_load_pipeline_params.transaction_bytes = params.epilogue.tma_transaction_bytes; + } + EpiLoadPipeline epi_load_pipeline(shared_storage.pipelines.epi_load, epi_load_pipeline_params); + + // Epilogue Store pipeline + using EpiStorePipeline = typename CollectiveEpilogue::StorePipeline; + typename EpiStorePipeline::Params epi_store_pipeline_params; + 
epi_store_pipeline_params.always_wait = true; + EpiStorePipeline epi_store_pipeline(epi_store_pipeline_params); + + // Initialize starting pipeline states for the collectives + // Epilogue store pipe is producer-only (consumer is TMA unit, waits via scoreboarding) + typename CollectiveMainloop::PipelineState mainloop_pipe_consumer_state; + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_consumer_state; + + // For the DMA Load (producer) we start with an opposite phase + // i.e., we skip all waits since we know that the buffer is indeed empty + PipelineState mainloop_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_load_pipe_producer_state = cutlass::make_producer_start_state(); + PipelineState epi_store_pipe_producer_state = cutlass::make_producer_start_state(); + + auto cluster_wait_fn = [&] () { + // We need this to guarantee that the Pipeline init is visible + // To all producers and consumer thread blocks in the Cluster + if constexpr (size(ClusterShape{}) > 1) { + // Non-prefetcher warps arrive and wait, + // Prefetcher warp can go ahead without waiting. + cute::cluster_arrive_relaxed(); + if (warp_group_role != WarpGroupRole::Producer || + producer_warp_role != ProducerWarpRole::PrefetchMK) { + cute::cluster_wait(); + } + return [] () {}; + } + else { + // __syncthreads() but only for non prefetcher warps + if (should_prefetch) { + + // Use a named barrier to let the prefetcher warp start loading into the L2 + // without waiting to sync with all other warps. + // All other warps need to sync because the mainloop pipeline init + // should be visible to all of them. + // Prefetcher has its own barriers, and the only warps it would need to sync + // with would be the DMA warps. + using ClusterSyncWithPrefetchBarrier = typename cutlass::arch::NamedBarrier; + auto prefetcher_arrive_barrier = ClusterSyncWithPrefetchBarrier( + blockDim.x * blockDim.y * blockDim.z, + /*reserved_named_barriers_*/ 14); + // Prefetcher warp doesn't arrive on this barrier. + auto cluster_arrive_barrier = ClusterSyncWithPrefetchBarrier( + blockDim.x * blockDim.y * blockDim.z - NumThreadsPerWarp, + /*reserved_named_barriers_*/ 15); + + if (warp_group_role == WarpGroupRole::Producer && producer_warp_role == ProducerWarpRole::PrefetchMK) { + __syncwarp(); + prefetcher_arrive_barrier.arrive(); + } + else if (warp_group_role == WarpGroupRole::Producer) { + prefetcher_arrive_barrier.arrive_and_wait(); + cluster_arrive_barrier.arrive_and_wait(); + } + else { + prefetcher_arrive_barrier.arrive(); + cluster_arrive_barrier.arrive_and_wait(); + } + } else { + __syncthreads(); + } + return [] () {}; + } + } (); + + // Preconditions + static_assert(cute::rank(StrideA{}) == 3, "StrideA must be rank-3: [M, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideB{}) == 3, "StrideB must be rank-3: [N, K, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]. If batch mode is not needed, set L stride to Int<0>."); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]. 
If batch mode is not needed, set L stride to Int<0>."); + + // Optionally append 1s until problem shape is rank-4 in case it is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(params.problem_shape, Int<1>{}); + + // Get the appropriate blocks for this thread block -- potential for thread block locality + auto blk_shape = TileShape{}; // (BLK_M,BLK_N,BLK_K) + TiledMma tiled_mma; + + // In a warp specialized kernel, collectives expose data movement and compute operations separately + CollectiveMainloop collective_mainloop; + CollectiveEpilogue collective_epilogue(params.epilogue, shared_storage.tensors.epilogue); + + // Prepare and partition the input tensors. Expects a tuple of tensors where: + // get<0>(load_inputs) is the tma tensor A after local tiling so that it has shape (BLK_M,BLK_K,m,k,l) + // get<1>(load_inputs) is the tma tensor B after local tiling so that it has shape (BLK_N,BLK_K,n,k,l) + auto load_inputs = collective_mainloop.load_init(problem_shape_MNKL, params.mainloop); + static_assert(cute::tuple_size_v >= 2, "Output of load_init must have at least two elements (A, B)"); + + // Extract out partitioned A and B. + Tensor gA_mkl = get<0>(load_inputs); + Tensor gB_nkl = get<1>(load_inputs); + + // Compute m_coord, n_coord, and l_coord with their post-tiled shapes + auto m_coord = idx2crd(int(blockIdx.x), shape<2>(gA_mkl)); + auto n_coord = idx2crd(int(blockIdx.y), shape<2>(gB_nkl)); + auto l_coord = idx2crd(int(blockIdx.z), shape<4>(gB_nkl)); + auto blk_coord = make_coord(m_coord, n_coord, _, l_coord); + + // Get pipeline iterators and increments from tensor shapes + auto k_tile_iter = cute::make_coord_iterator(shape<3>(gA_mkl)); + auto k_tile_count = size<3>(gA_mkl); + + // Wait for all thread blocks in the Cluster + cluster_wait_fn(); + + if (warp_group_role == WarpGroupRole::Producer) { + if (producer_warp_role == ProducerWarpRole::Warp0) { + if constexpr(SplitWarps) { + collective_mainloop.load_NK( + params.mainloop, + mainloop_pipeline, + prefetcher_pipeline, + mainloop_pipe_producer_state, + gB_nkl, + blk_coord, + k_tile_iter, k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop + ); + } + else { + collective_mainloop.load( + params.mainloop, + mainloop_pipeline, + prefetcher_pipeline, + mainloop_pipe_producer_state, + gA_mkl, gB_nkl, + blk_coord, + k_tile_iter, k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop + ); + } + // Update starting mainloop pipeline state for the pipeline drain + mainloop_pipe_producer_state.advance(k_tile_count); + // Make sure mainloop consumer has been waited upon before issuing epilogue load + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + + if (collective_epilogue.is_producer_load_needed()) { + // Ensure warp is converged before issuing epilogue loads + __syncwarp(); + epi_load_pipe_producer_state = collective_epilogue.load( + epi_load_pipeline, + epi_load_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + tiled_mma, + lane_idx, + shared_storage.tensors.epilogue + ); + collective_epilogue.load_tail(epi_load_pipeline, epi_load_pipe_producer_state); + } + } + else if (SplitWarps && producer_warp_role == ProducerWarpRole::Warp2) { + collective_mainloop.load_MK( + params.mainloop, + mainloop_pipeline, + prefetcher_pipeline, + mainloop_pipe_producer_state, + gA_mkl, + blk_coord, + k_tile_iter, k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop + ); + // Update starting mainloop 
pipeline state for the pipeline drain + mainloop_pipe_producer_state.advance(k_tile_count); + // Make sure mainloop consumer has been waited upon before issuing epilogue load + collective_mainloop.load_tail(mainloop_pipeline, mainloop_pipe_producer_state); + } else if (producer_warp_role == ProducerWarpRole::PrefetchMK && should_prefetch) { + collective_mainloop.prefetch_MK( + params.mainloop, + prefetcher_pipeline, + mainloop_pipe_producer_state, + gA_mkl, + blk_coord, + k_tile_iter, k_tile_count, + lane_idx, + block_rank_in_cluster, + shared_storage.tensors.mainloop + ); + } + } + else if (warp_group_role == WarpGroupRole::Consumer) { + Tensor accumulators = partition_fragment_C(tiled_mma, take<0,2>(blk_shape)); // (MMA,MMA_M,MMA_N) + + collective_mainloop.mma( + mainloop_pipeline, + mainloop_pipe_consumer_state, + accumulators, + k_tile_count, + warp_group_thread_idx, + shared_storage.tensors.mainloop, + params.mainloop + ); + + // Make sure the math instructions are done and free buffers before entering the epilogue + collective_mainloop.mma_tail( + mainloop_pipeline, + mainloop_pipe_consumer_state, + k_tile_count + ); + + // Epilogue and write to gD + auto [epi_load_pipe_consumer_state_next, epi_store_pipe_producer_state_next] = + collective_epilogue.store( + epi_load_pipeline, + epi_load_pipe_consumer_state, + epi_store_pipeline, + epi_store_pipe_producer_state, + problem_shape_MNKL, + blk_shape, + blk_coord, + accumulators, + tiled_mma, + warp_group_thread_idx, + shared_storage.tensors.epilogue + ); + + collective_epilogue.store_tail( + epi_load_pipeline, + epi_load_pipe_consumer_state_next, + epi_store_pipeline, + epi_store_pipe_producer_state_next + ); + } +#endif + } +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::kernel diff --git a/examples/63_hopper_gemm_with_weight_prefetch/pipeline/prefetch_pipeline_sm90.hpp b/examples/63_hopper_gemm_with_weight_prefetch/pipeline/prefetch_pipeline_sm90.hpp new file mode 100644 index 0000000000..7abd39ccfc --- /dev/null +++ b/examples/63_hopper_gemm_with_weight_prefetch/pipeline/prefetch_pipeline_sm90.hpp @@ -0,0 +1,161 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/barrier.h" +#include "cute/container/array.hpp" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +namespace detail { + +// MSVC work-around +template +struct PrefetcherPipelineSharedStorage { + using TransactionBarrier = cutlass::arch::ClusterTransactionBarrier; + using Barrier = cutlass::arch::ClusterBarrier; + + TransactionBarrier tma_barrier[Stages]; + Barrier producer_ready_barrier; +}; + +} // end namespace detail + +using namespace cute; + +// Prefetcher pipeline is modeled after PipelineTmaAsync, with a cluster transaction +// barrier providing control over the number of concurrent outstanding TMA loads. +// There is also an additional cluster barrier which is only used when `prefetch_ratio` is unset. +// `prefetch_ratio` determines how many K tiles get loaded, and when unset, the prefetcher checks +// whether DMA warps are done waiting on griddepcontrol, and if so, stops issuing more TMA loads. +template +class PrefetchPipeline { +public : + static constexpr uint32_t Stages = Stages_; + using SharedStorage = detail::PrefetcherPipelineSharedStorage; + + using TransactionBarrier = typename SharedStorage::TransactionBarrier; + using Barrier = typename SharedStorage::Barrier; + using PrefetcherBarrierType = typename TransactionBarrier::ValueType; + + struct Params { + uint32_t transaction_bytes = 0; + uint32_t num_prefetchers = 1; + bool should_prefetch = false; + }; + + // Constructor + CUTLASS_DEVICE + PrefetchPipeline(SharedStorage& storage, Params params) + : params_(params) + , tma_barrier_ptr_(&storage.tma_barrier[0]) + , producer_ready_barrier_ptr_(&storage.producer_ready_barrier) { + + int lane_predicate = cute::elect_one_sync(); + if (params.should_prefetch && lane_predicate) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Stages; ++i) { + tma_barrier_ptr_[i].init(params.num_prefetchers); + } + producer_ready_barrier_ptr_[0].init(1); + } + } + + CUTLASS_DEVICE + void producer_arrive() { + if (params_.should_prefetch) { + producer_ready_barrier_ptr_[0].arrive(); + } + } + + CUTLASS_DEVICE + bool have_producers_arrived() { + if (params_.should_prefetch) { + uint32_t barrier_status_ = producer_ready_barrier_ptr_[0].try_wait(0); + auto barrier_status = static_cast(barrier_status_); + if (barrier_status == BarrierStatus::WaitDone) { + return true; // exit prefetcher loop + } + return false; + } + return true; + } + + CUTLASS_DEVICE + void prefetcher_acquire(uint32_t stage, uint32_t phase, bool should_wait) { + if (params_.should_prefetch) { + if (should_wait) { + tma_barrier_ptr_[stage].wait(phase ^ 1); + } + tma_barrier_ptr_[stage].arrive_and_expect_tx(params_.transaction_bytes); + } + } + + CUTLASS_DEVICE + void advance_prefetcher_state(uint32_t& stage, 
uint32_t& phase) { + if (params_.should_prefetch) { + stage++; + if (stage == Stages) { + stage = 0; + phase ^= 1; + } + } + } + + CUTLASS_DEVICE + void prefetcher_tail(uint32_t stage, uint32_t phase) { + if (params_.should_prefetch) { + // Wait on any already-issued loads + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < stage; ++i) { + tma_barrier_ptr_[i].wait(phase); + } + } + } + + CUTLASS_DEVICE + PrefetcherBarrierType* prefetcher_get_barrier(uint32_t stage) { + return reinterpret_cast(&tma_barrier_ptr_[stage]); + } + +private : + TransactionBarrier* tma_barrier_ptr_ = nullptr; + Barrier* producer_ready_barrier_ptr_ = nullptr; + Params params_; + +}; + +} // end namespace cutlass diff --git a/examples/64_ada_fp8_gemm_grouped/CMakeLists.txt b/examples/64_ada_fp8_gemm_grouped/CMakeLists.txt new file mode 100644 index 0000000000..183202593c --- /dev/null +++ b/examples/64_ada_fp8_gemm_grouped/CMakeLists.txt @@ -0,0 +1,35 @@ + +# Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + + +cutlass_example_add_executable( + 64_ada_fp8_gemm_grouped + ada_fp8_gemm_grouped.cu + ) diff --git a/examples/64_ada_fp8_gemm_grouped/ada_fp8_gemm_grouped.cu b/examples/64_ada_fp8_gemm_grouped/ada_fp8_gemm_grouped.cu new file mode 100644 index 0000000000..8e3dbbb08b --- /dev/null +++ b/examples/64_ada_fp8_gemm_grouped/ada_fp8_gemm_grouped.cu @@ -0,0 +1,1208 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Ada FP8 GEMM Grouped With Per-Group Scale Example. + + This workload computes a batch of GEMM operations with distinct problem sizes. Pointers to matrices + in Global Memory are passed to the kernel in array (also held in Global Memory). Similarly, + leading dimensions and problem sizes are stored in arrays in GMEM. + + This differs from "Batched Array" GEMM because the size of each GEMM problem in the Grouped GEMM + concept may be distinct. + + The differences between this and the examples/24_gemm_grouped are: (1) this example scales the output of each GEMM by a different scalar value specified by alpha_ptr_array. (2) this example uses FP8 tensorcore. + + This benchmark program initializes a workspace with random problem sizes for a given number of + groups. Command line options enable overriding M, N, and/or K dimensions with uniform values to + model problems more similar to the traditional batched GEMM. + + Additionally, problem sizes are collected and binned to compute the same problem as a series of + conventional batched GEMMs (setup for this problem is not timed). This demonstrates the performance + enhancement achieved by implementing a specialized grouped GEMM kernel. 
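+
+    Conceptually, for the i-th problem in the group the kernel computes
+
+        D_i = alpha_i * (A_i * B_i) + beta * C_i
+
+    where each group has its own scalar alpha_i (supplied through alpha_ptr_array) and beta is a
+    single scalar shared by all groups.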
+ + Examples: + + # Runs a grouped GEMM with 100 random problem sizes + $ ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --groups=100 + + # Runs a grouped GEMM with 100 random problem sizes (with GEMM-K dimension equal to 1024) + $ ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --groups=100 --k=1024 --verbose=true + + # Runs a grouped GEMM that is equivalent to a batched GEMM + $ ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --groups=100 --m=2048 --n=1024 --k=1024 --verbose=true + + # Execute Grouped GEMM and profile with NSight + $ nv-nsight-cu-cli ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --m=256 --n=256 --k=256 --verbose=true \ + --iterations=1 --reference-check=false + +*/ + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include +#include +#include + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/kernel/gemm_grouped_per_group_scale.h" +#include "cutlass/gemm/kernel/default_gemm_grouped_per_group_scale.h" +#include "cutlass/gemm/device/gemm_grouped.h" +#include "cutlass/gemm/device/gemm_universal.h" + +#include "cutlass/util/command_line.h" +#include "cutlass/util/distribution.h" +#include "cutlass/util/device_memory.h" +#include "cutlass/util/tensor_view_io.h" +#include "cutlass/util/host_tensor.h" +#include "cutlass/util/reference/host/gemm_complex.h" +#include "cutlass/util/reference/device/gemm.h" +#include "cutlass/util/reference/device/gemm_complex.h" +#include "cutlass/util/reference/host/tensor_compare.h" +#include "cutlass/util/reference/host/tensor_copy.h" +#include "cutlass/util/reference/device/tensor_fill.h" +#include "cutlass/util/reference/host/tensor_norm.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Result structure +struct Result { + + double runtime_ms; + double initialization_time_ms; + double gflops; + cutlass::Status status; + cudaError_t error; + bool passed; + + // + // Methods + // + + Result( + double runtime_ms = 0, + double initialization_time_ms = 0, + double gflops = 0, + cutlass::Status status = cutlass::Status::kSuccess, + cudaError_t error = cudaSuccess + ): + runtime_ms(runtime_ms), initialization_time_ms(initialization_time_ms), gflops(gflops), + status(status), error(error), passed(true) { } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Hash function for cutlass::gemm::GemmCoord +struct HashGemmCoord { + size_t operator()(cutlass::gemm::GemmCoord const &problem) const { + std::hash hasher; + return (hasher(problem.m() * 3)) ^ (hasher(1 + problem.n() * 5)) ^ (hasher(2 + problem.k() * 7)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Command line options parsing +struct Options { + + bool help; + bool error; + bool reference_check; + bool profile_initialization; + bool sort_problems; + + std::vector problem_sizes; + + // problem size bins + std::unordered_map< + cutlass::gemm::GemmCoord, + std::vector, + HashGemmCoord> problem_bins; + + int alignment; + int problem_count; + int iterations; + int cuda_streams; + bool verbose; + float alpha; + std::vector alpha_array; + float beta; + std::string benchmark_path; + + std::string output_tag; + std::ofstream output_file; + + using GroupScheduleMode = cutlass::gemm::kernel::GroupScheduleMode; + std::vector scheduler_modes; + + 
std::unordered_map + str_to_scheduler_mode = { + {"kDeviceOnly", GroupScheduleMode::kDeviceOnly}, + {"kHostPrecompute", GroupScheduleMode::kHostPrecompute} + }; + + struct GroupScheduleModeHash { + size_t operator()(GroupScheduleMode m) const { + return static_cast(m); + } + }; + + std::unordered_map + scheduler_mode_to_str = { + {GroupScheduleMode::kDeviceOnly, "kDeviceOnly"}, + {GroupScheduleMode::kHostPrecompute, "kHostPrecompute"} + }; + + std::vector all_scheduler_modes = {GroupScheduleMode::kDeviceOnly, GroupScheduleMode::kHostPrecompute}; + + // + // Methods + // + + Options(): + help(false), + error(false), + alignment(16), + reference_check(true), + profile_initialization(false), + sort_problems(false), + problem_count(15), + iterations(20), + cuda_streams(0), + verbose(false), + alpha(1), + beta(), + scheduler_modes({GroupScheduleMode::kDeviceOnly}) + { } + + // Parses the command line + void parse(int argc, char const **args) { + cutlass::CommandLine cmd(argc, args); + + if (cmd.check_cmd_line_flag("help")) { + help = true; + return; + } + + cmd.get_cmd_line_argument("alignment", alignment, 16); + cmd.get_cmd_line_argument("groups", problem_count, 15); + cmd.get_cmd_line_argument("alpha", alpha, 1.0f); + cmd.get_cmd_line_argument("beta", beta, 0.0f); + cmd.get_cmd_line_argument("iterations", iterations, 20); + cmd.get_cmd_line_argument("streams", cuda_streams, 0); + cmd.get_cmd_line_argument("verbose", verbose, false); + cmd.get_cmd_line_argument("reference-check", reference_check, true); + cmd.get_cmd_line_argument("profile-initialization", profile_initialization, false); + cmd.get_cmd_line_argument("sort-problems", sort_problems, false); + cmd.get_cmd_line_argument("benchmark", benchmark_path); + + std::vector scheduler_mode_strs; + cmd.get_cmd_line_arguments("scheduler-modes", scheduler_mode_strs); + + if (!scheduler_mode_strs.empty()) { + scheduler_modes.clear(); + if (scheduler_mode_strs.size() == 1 && scheduler_mode_strs[0] == "all") { + scheduler_modes = all_scheduler_modes; + } else { + for (std::string precomp_str : scheduler_mode_strs) { + auto it = str_to_scheduler_mode.find(precomp_str); + if (it != str_to_scheduler_mode.end()) { + scheduler_modes.push_back(it->second); + } else if (precomp_str == "all") { + std::cerr << "Flag --scheduler-modes=all must not contain other scheduler modes in list." << std::endl; + error = true; + return; + } else { + std::cerr << "Unrecognized scheduler mode '" << precomp_str << "'" << std::endl; + error = true; + return; + } + } + } + } + + std::string output_path; + cmd.get_cmd_line_argument("tag", output_tag); + cmd.get_cmd_line_argument("output_file", output_path); + + if (!output_path.empty()) { + + std::ios_base::openmode open_mode = std::ios_base::out; + + std::ifstream input_file(output_path.c_str()); + + if (input_file.good()) { + open_mode = std::ios_base::app; + input_file.close(); + } + + output_file.open(output_path.c_str(), open_mode); + + if (output_file.good() && open_mode != std::ios_base::app) { + output_file << "Tag,Provider,Kind,Groups,Runtime,GFLOPs\n"; + } + } + + // Decide how to initialize the problems + if (!benchmark_path.empty()) { + if (!benchmark_problems()) { + error = true; + problem_sizes.clear(); + return; + } + } + else { + randomize_problems(cmd); + } + + // Post-process the problem sizes + bin_problems(); + + // Initalize alpha array + randomize_alpha_ptr_array(cmd); + } + + void randomize_problems(cutlass::CommandLine &cmd) { + + // + // For now, randomly choose the problem sizes. 
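+    // Unless a dimension is fixed via --m/--n/--k, each extent is chosen as
+    // alignment * (1..256), i.e. a random multiple of the requested alignment.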
+ // + + int cmd_line_m = -1; + int cmd_line_n = -1; + int cmd_line_k = -1; + + cmd.get_cmd_line_argument("m", cmd_line_m); + cmd.get_cmd_line_argument("n", cmd_line_n); + cmd.get_cmd_line_argument("k", cmd_line_k); + + problem_sizes.reserve(problem_count); + + for (int i = 0; i < problem_count; ++i) { + + int m = cmd_line_m; + int n = cmd_line_n; + int k = cmd_line_k; + + if (m < 1) { + m = alignment * ((rand() % 256) + 1); + } + + if (n < 1) { + n = alignment * ((rand() % 256) + 1); + } + + if (k < 1) { + k = alignment * ((rand() % 256) + 1); + } + + cutlass::gemm::GemmCoord problem(m, n, k); + + problem_sizes.push_back(problem); + } + } + + void randomize_alpha_ptr_array(cutlass::CommandLine &cmd) { + alpha_array.resize(problem_count); + for (int i = 0; i < problem_count; ++i) { + alpha_array[i] = static_cast((rand() % 100) - 50 + alpha); + } + } + + /// Load a benchmark + bool benchmark_problems() { + std::ifstream file(benchmark_path); + if (!file.good()) { + return false; + } + + while (file.good()) { + + int idx = -1; + std::string extent_str; + + file >> idx >> extent_str; + + if (idx < 0 || extent_str.empty()) { + break; + } + + cutlass::gemm::GemmCoord extent; + std::vector tokens; + + cutlass::CommandLine::tokenize(tokens, extent_str, 'x'); + + for (int i = 0; i < int(tokens.size()); ++i) { + int x = std::atoi(tokens.at(i).c_str()); + + // round up + if (x % alignment) { + x += (alignment - (x % alignment)); + } + + extent.at(i) = x; + } + + if (extent.product()) { + problem_sizes.push_back(extent); + } + } + + return true; + } + + /// Post processes the problems + void bin_problems() { + + problem_bins.clear(); + + problem_count = int(problem_sizes.size()); + + // + // Insert the problem sizes into a sorted container class. This is *NOT* necessary + // to run the CUTLASS kernel, but it enables the execution of cublas's batched GEMM. + // + for (int i = 0; i < int(problem_sizes.size()); ++i) { + auto it = problem_bins.find(problem_sizes.at(i)); + if (it == problem_bins.end()) { + problem_bins.insert({problem_sizes.at(i), std::vector({i}) }); + } + else { + it->second.push_back(i); + } + } + } + + /// Prints the usage statement. + std::ostream & print_usage(std::ostream &out) const { + + out << "64_ada_fp8_gemm_grouped\n\n" + << " This example profiles the performance of a 'grouped' GEMM kernel. This is similar to batched GEMM\n" + << " in that multiple, independent GEMMs are computed by one grid launch. It differs in that each\n" + << " 'group' may compute a unique problem size. Problem sizes and pointers to matrices are both stored\n" + << " in device Global Memory and loaded by the kernel.\n\n" + << "Options:\n\n" + << " --help If specified, displays this usage statement.\n\n" + << " --benchmark= Executes a benchmark problem size.\n" + << " --output_file= Path to a CSV file to output results. If it exists already, results are appended.\n" + << " --tag= String tag to prepend to the CSV file.\n" + << " --groups= Number of individual GEMM problems (default: --groups=15)\n" + << " --m= Sets the M dimension for all groups. Otherwise, it is selected randomly\n" + << " --n= Sets the N dimension for all groups. Otherwise, it is selected randomly\n" + << " --k= Sets the K dimension for all groups. 
Otherwise, it is selected randomly\n" + << " --alpha= Epilogue scalar alpha (real part)\n" + << " --beta= Epilogue scalar beta (real part)\n" + << " --scheduler-modes= List of scheduler modes to be profile for grouped GEMM scheduler (default: --scheduler_modes=kDeviceOnly)\n" + << " --iterations= Number of profiling iterations to perform.\n" + << " --reference-check= If true, performs reference check.\n" + << " --verbose= If true, prints problem sizes and batching structure.\n" + << " --profile-initialization= If true, profiles the device-level kernel's initialization.\n" + << " --sort-problems= If true, sorts problem sizes in descending order of GEMM-K dimension.\n"; + + out << "\n\nExamples:\n\n" + + << "# Runs a grouped GEMM with 100 random problem sizes\n" + << "$ ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --groups=100\n\n" + + << "# Runs a grouped GEMM with 100 random problem sizes (with GEMM-K dimension equal to 1024)\n" + << "$ ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --groups=100 --k=1024 --verbose=true\n\n" + + << "# Runs a grouped GEMM that is equivalent to a batched GEMM\n" + << "$ ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --groups=100 --m=2048 --n=1024 --k=1024 --verbose=true\n\n" + + << "# Runs a grouped GEMM with each different scheduler mode\n" + << "$ ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --scheduler-modes=all\n\n" + + << "# Runs a grouped GEMM with each different scheduler mode and profiles host-side initialization time\n" + << "$ ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --scheduler-modes=all --profile-initialization=true\n\n" + + << "# Runs a grouped GEMM problem given an externally supplied benchmark file. This is a text file in which\n" + << "# Each line contains a unique group index and an MxNxK triple indicating problemsize.\n" + << "#\n" + << "# For example, assume the following are the contents of 'problems.txt'\n" + << "#\n" + << "# 0 1024x256x520\n" + << "# 1 520x264x1024\n" + << "# 2 96x48x1024\n" + << "#\n" + << "$ ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --benchmark=problems.txt\n\n" + + << "# Execute Grouped GEMM and profile with NSight\n" + << "$ nv-nsight-cu-cli ./examples/64_ada_fp8_gemm_grouped/64_ada_fp8_gemm_grouped --m=256 --n=256 --k=256 --verbose=true --iterations=1 --reference-check=false\n\n"; + + return out; + } + + /// Compute performance in GFLOP/s + double gflops(double runtime_s) const { + + // Number of real-valued multiply-adds + int64_t fmas = int64_t(); + + for (auto const & problem : problem_sizes) { + fmas += problem.product(); + } + + // Two flops per multiply-add + return 2.0 * double(fmas) / double(1.0e9) / runtime_s; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template +class BaseTestbed { +public: + // + // Type definitions + // + + using ElementA = typename Gemm::ElementA; + using ElementB = typename Gemm::ElementB; + using ElementC = typename Gemm::ElementC; + using ElementAccumulator = typename Gemm::ElementAccumulator; + + using EpilogueOutputOp = typename Gemm::GemmKernel::Epilogue::OutputOp; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using LayoutA = typename Gemm::LayoutA; + using LayoutB = typename Gemm::LayoutB; + using LayoutC = typename Gemm::LayoutC; + + using MatrixCoord = typename LayoutC::TensorCoord; + + using DeviceGemmReference = cutlass::reference::device::Gemm< + ElementA, + LayoutA, + ElementB, + LayoutB, + 
ElementC, + LayoutC, + ElementAccumulator, + ElementAccumulator>; + + // + // Data members + // + + Options & options; + + /// Initialization + cutlass::Distribution::Kind init_A; + cutlass::Distribution::Kind init_B; + cutlass::Distribution::Kind init_C; + uint32_t seed; + + cutlass::DeviceAllocation problem_sizes_device; + + std::vector offset_A; + std::vector offset_B; + std::vector offset_C; + std::vector offset_D; + + std::vector lda_host; + std::vector ldb_host; + std::vector ldc_host; + std::vector ldd_host; + std::vector alpha_ptr_array_host; + + cutlass::DeviceAllocation lda; + cutlass::DeviceAllocation ldb; + cutlass::DeviceAllocation ldc; + cutlass::DeviceAllocation ldd; + + cutlass::DeviceAllocation block_A; + cutlass::DeviceAllocation block_B; + cutlass::DeviceAllocation block_C; + cutlass::DeviceAllocation block_D; + cutlass::DeviceAllocation alpha_array_device; + + cutlass::DeviceAllocation ptr_A; + cutlass::DeviceAllocation ptr_B; + cutlass::DeviceAllocation ptr_C; + cutlass::DeviceAllocation ptr_D; + cutlass::DeviceAllocation alpha_ptr_array_device; + + BaseTestbed( + Options &options_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3080 + ): + options(options_), init_A(init_A_), init_B(init_B_), init_C(init_C_), seed(seed_) { } + + int problem_count() const { + return options.problem_count; + } + + /// Helper to initialize a tensor view + template + void initialize_tensor( + Element *ptr, + size_t capacity, + cutlass::Distribution::Kind dist_kind, + uint32_t seed) { + + if (dist_kind == cutlass::Distribution::Uniform) { + + Element scope_max, scope_min; + int bits_input = cutlass::sizeof_bits::value; + int bits_output = cutlass::sizeof_bits::value; + + if (bits_input == 1) { + scope_max = static_cast(2); + scope_min = static_cast(0); + } else if (bits_input <= 8) { + scope_max = static_cast(2); + scope_min = static_cast(-2); + } else if (bits_output == 16) { + if (cutlass::sizeof_bits::value <= 16) { + scope_max = static_cast(5); + scope_min = static_cast(-5); + } + else { + scope_max = static_cast(8); + scope_min = static_cast(-8); + } + } else { + scope_max = static_cast(8); + scope_min = static_cast(-8); + } + + cutlass::reference::device::BlockFillRandomUniform( + ptr, capacity, seed, scope_max, scope_min, 0); + } + else if (dist_kind == cutlass::Distribution::Gaussian) { + + cutlass::reference::device::BlockFillRandomGaussian( + ptr, capacity, seed, Element(), Element(0.5f)); + } + else if (dist_kind == cutlass::Distribution::Sequential) { + + // Fill with increasing elements + cutlass::reference::device::BlockFillSequential( + ptr, capacity, Element(1), Element()); + } + else { + + // Fill with all 1s + cutlass::reference::device::BlockFillSequential( + ptr, capacity, Element(), Element(1)); + } + } + + /// Allocates device-side data + void allocate() { + int64_t total_elements_A = 0; + int64_t total_elements_B = 0; + int64_t total_elements_C = 0; + int64_t total_elements_D = 0; + + lda_host.resize(problem_count()); + ldb_host.resize(problem_count()); + ldc_host.resize(problem_count()); + ldd_host.resize(problem_count()); + + for (int32_t i = 0; i < problem_count(); ++i) { + + auto problem = options.problem_sizes.at(i); + + lda_host.at(i) = LayoutA::packed({problem.m(), problem.k()}).stride(0); + ldb_host.at(i) = LayoutB::packed({problem.k(), problem.n()}).stride(0); + 
ldc_host.at(i) = LayoutC::packed({problem.m(), problem.n()}).stride(0); + ldd_host.at(i) = LayoutC::packed({problem.m(), problem.n()}).stride(0); + + offset_A.push_back(total_elements_A); + offset_B.push_back(total_elements_B); + offset_C.push_back(total_elements_C); + offset_D.push_back(total_elements_D); + + int64_t elements_A = problem.m() * problem.k(); + int64_t elements_B = problem.k() * problem.n(); + int64_t elements_C = problem.m() * problem.n(); + int64_t elements_D = problem.m() * problem.n(); + + total_elements_A += elements_A; + total_elements_B += elements_B; + total_elements_C += elements_C; + total_elements_D += elements_D; + } + + lda.reset(problem_count()); + ldb.reset(problem_count()); + ldc.reset(problem_count()); + ldd.reset(problem_count()); + + block_A.reset(total_elements_A); + block_B.reset(total_elements_B); + block_C.reset(total_elements_C); + block_D.reset(total_elements_D); + + alpha_ptr_array_host.resize(problem_count()); + alpha_array_device.reset(problem_count()); + alpha_ptr_array_device.reset(problem_count()); + } + + /// Initializes device-side data + void initialize() { + problem_sizes_device.reset(problem_count()); + problem_sizes_device.copy_from_host(options.problem_sizes.data()); + + lda.copy_from_host(lda_host.data()); + ldb.copy_from_host(ldb_host.data()); + ldc.copy_from_host(ldc_host.data()); + ldd.copy_from_host(ldd_host.data()); + + // + // Assign pointers + // + + std::vector ptr_A_host(problem_count()); + std::vector ptr_B_host(problem_count()); + std::vector ptr_C_host(problem_count()); + std::vector ptr_D_host(problem_count()); + + for (int32_t i = 0; i < problem_count(); ++i) { + ptr_A_host.at(i) = block_A.get() + offset_A.at(i); + ptr_B_host.at(i) = block_B.get() + offset_B.at(i); + ptr_C_host.at(i) = block_C.get() + offset_C.at(i); + ptr_D_host.at(i) = block_D.get() + offset_D.at(i); + } + + ptr_A.reset(problem_count()); + ptr_A.copy_from_host(ptr_A_host.data()); + + ptr_B.reset(problem_count()); + ptr_B.copy_from_host(ptr_B_host.data()); + + ptr_C.reset(problem_count()); + ptr_C.copy_from_host(ptr_C_host.data()); + + ptr_D.reset(problem_count()); + ptr_D.copy_from_host(ptr_D_host.data()); + + // + // Initialize the problems of the workspace + // + + initialize_tensor(block_A.get(), block_A.size(), init_A, seed * 2021); + initialize_tensor(block_B.get(), block_B.size(), init_B, seed * 2022); + initialize_tensor(block_C.get(), block_C.size(), init_C, seed * 2023); + + cutlass::reference::device::BlockFillSequential( + block_D.get(), block_D.size(), ElementC(), ElementC()); + + // Initialize alpha array + alpha_array_device.copy_from_host(options.alpha_array.data()); + for (int32_t i = 0; i < problem_count(); ++i) { + alpha_ptr_array_host.at(i) = alpha_array_device.get() + i; + } + alpha_ptr_array_device.copy_from_host(alpha_ptr_array_host.data()); + } + + /// Verifies the result is a GEMM + bool verify() { + + bool passed = true; + + for (int32_t i = 0; i < problem_count(); ++i) { + cutlass::gemm::GemmCoord problem = options.problem_sizes.at(i); + + LayoutA layout_A(lda_host.at(i)); + LayoutB layout_B(ldb_host.at(i)); + LayoutC layout_C(ldc_host.at(i)); + LayoutC layout_D(ldd_host.at(i)); + + MatrixCoord extent_A{problem.m(), problem.k()}; + MatrixCoord extent_B{problem.k(), problem.n()}; + MatrixCoord extent_C{problem.m(), problem.n()}; + + cutlass::TensorView view_A(block_A.get() + offset_A.at(i), layout_A, extent_A); + cutlass::TensorView view_B(block_B.get() + offset_B.at(i), layout_B, extent_B); + cutlass::TensorView 
view_C(block_C.get() + offset_C.at(i), layout_C, extent_C); + + cutlass::DeviceAllocation block_Ref(layout_D.capacity(extent_C)); + cutlass::TensorView view_Ref_device(block_Ref.get(), layout_D, extent_C); + + // Reference GEMM + cutlass::reference::device::GemmComplex< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementCompute, ElementAccumulator + >( + problem, + options.alpha_array[i], + view_A, + Gemm::kTransformA, + view_B, + Gemm::kTransformB, + options.beta, + view_C, + view_Ref_device, + ElementAccumulator(0) + ); + + // Copy to host memory + std::vector matrix_D(layout_D.capacity(extent_C)); + std::vector matrix_Ref(layout_D.capacity(extent_C)); + + cutlass::device_memory::copy_to_host(matrix_D.data(), block_D.get() + offset_D.at(i), matrix_D.size()); + cutlass::device_memory::copy_to_host(matrix_Ref.data(), block_Ref.get(), matrix_D.size()); + + cutlass::TensorView view_D( matrix_D.data(), layout_D, extent_C); + cutlass::TensorView view_Ref(matrix_Ref.data(), layout_D, extent_C); + + // Reference check + passed = cutlass::reference::host::TensorEquals(view_D, view_Ref); + + if (!passed) { + std::cerr << "\n***\nError - problem " << i << " failed the QA check\n***\n" << std::endl; + return passed; + } + } + + return passed; + } + +}; + +template +class TestbedGrouped : BaseTestbed { +public: + TestbedGrouped( + Options &options_, + cutlass::Distribution::Kind init_A_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_B_ = cutlass::Distribution::Uniform, + cutlass::Distribution::Kind init_C_ = cutlass::Distribution::Uniform, + uint32_t seed_ = 3080 + ): BaseTestbed(options_, init_A_, init_B_, init_C_, seed_) {} + + // Redefine GEMM with different GroupScheduleMode_ + using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGroupedPerGroupScale< + typename Gemm_::ElementA, + typename Gemm_::LayoutA, + Gemm_::kTransformA, + Gemm_::kAlignmentA, + typename Gemm_::ElementB, + typename Gemm_::LayoutB, + Gemm_::kTransformB, + Gemm_::kAlignmentB, + typename Gemm_::ElementC, + typename Gemm_::LayoutC, + typename Gemm_::ElementAccumulator, + typename Gemm_::OperatorClass, + typename Gemm_::ArchTag, + typename Gemm_::ThreadblockShape, + typename Gemm_::WarpShape, + typename Gemm_::InstructionShape, + typename Gemm_::EpilogueOutputOp, + typename Gemm_::ThreadblockSwizzle, + Gemm_::kStages, + GroupScheduleMode_>::GemmKernel; + + using Gemm = cutlass::gemm::device::GemmGrouped; + + /// Verbose printing of problem sizes + void print_problem_sizes() { + std::cout << std::endl; + + // Print groups + std::cout << this->problem_count() << " groups:\n"; + + int32_t idx = 0; + int64_t total_tiles = 0; + + for (auto const & problem : this->options.problem_sizes) { + int tiles = Gemm::problem_tile_count(problem); + total_tiles += tiles; + + std::cout << " [" << idx << "]: " + << problem.m() << "-by-" << problem.n() << "-by-" << problem.k() + << " (" << tiles << " threadblock tiles)" << "\n"; + + ++idx; + } + std::cout << std::endl; + } + + /// Sort problems in descending order of problem-K dimension + void sort_problems() { + Gemm::sort_problems(this->options.problem_count, + this->options.problem_sizes.data(), + this->lda_host.data(), + this->ldb_host.data(), + this->ldc_host.data(), + this->ldd_host.data(), + this->offset_A.data(), + this->offset_B.data(), + this->offset_C.data(), + this->offset_D.data()); + } + + /// Executes a grouped kernel and measures runtime + Result profile() { + std::string sched_mode = 
this->options.scheduler_mode_to_str.find(GroupScheduleMode_)->second; + + std::cout << std::endl; + std::cout << "Grouped GEMM (CUTLASS) with mode " << sched_mode << ":\n" + << "====================================================" << std::endl; + + Result result; + + int threadblock_count = Gemm::sufficient(this->options.problem_sizes.data(), this->options.problem_count); + + // Early exit + if (!threadblock_count) { + std::cout << "Active CUDA device lacks hardware resources to run CUTLASS Grouped GEMM kernel." << std::endl; + return result; + } + + result.passed = false; + + // Initialize the problem + this->allocate(); + if (this->options.sort_problems) { + sort_problems(); + } + this->initialize(); + + if (this->options.verbose) { + print_problem_sizes(); + } + + // Configure the GEMM arguments + typename Gemm::EpilogueOutputOp::ElementCompute ** alpha_ptr_array = this->alpha_ptr_array_device.get(); + typename Gemm::EpilogueOutputOp::Params epilogue_op(alpha_ptr_array, nullptr); + + // Configure GEMM arguments + typename Gemm::Arguments args( + this->problem_sizes_device.get(), + this->problem_count(), + threadblock_count, + epilogue_op, + this->ptr_A.get(), + this->ptr_B.get(), + this->ptr_C.get(), + this->ptr_D.get(), + this->lda.get(), + this->ldb.get(), + this->ldc.get(), + this->ldd.get(), + this->options.problem_sizes.data() + ); + + // Initialize the GEMM object + Gemm gemm; + + size_t workspace_size = gemm.get_workspace_size(args); + cutlass::DeviceAllocation workspace(workspace_size); + + result.status = gemm.initialize(args, workspace.get()); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to initialize CUTLASS Grouped GEMM kernel." << std::endl; + return result; + } + + // Run the grouped GEMM object + result.status = gemm.run(); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run CUTLASS Grouped GEMM kernel." << std::endl; + return result; + } + + // Wait for completion + result.error = cudaDeviceSynchronize(); + + if (result.error != cudaSuccess) { + std::cerr << "Kernel execution error: " << cudaGetErrorString(result.error); + return result; + } + + // + // Verify correctness + // + result.passed = true; + + if (this->options.reference_check) { + result.passed = this->verify(); + } + + // + // Warm-up run of the grouped GEMM object + // + result.status = gemm.run(); + + if (result.status != cutlass::Status::kSuccess) { + std::cerr << "Failed to run CUTLASS Grouped GEMM kernel." << std::endl; + return result; + } + + // + // Construct events + // + + cudaEvent_t events[2]; + + for (auto & event : events) { + result.error = cudaEventCreate(&event); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventCreate() failed: " << cudaGetErrorString(result.error) << std::endl; + return -1; + } + } + + // Record an event at the start of a series of GEMM operations + result.error = cudaEventRecord(events[0]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // + // Run profiling loop + // + + for (int iter = 0; iter < this->options.iterations; ++iter) { + gemm(); + } + + // + // Stop profiling loop + // + + // Record an event when the GEMM operations have been launched. + result.error = cudaEventRecord(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventRecord() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Wait for work on the device to complete. 
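+    // (The two events bracket the launch loop above; the elapsed time is later divided by the
+    // iteration count to report an average per-launch runtime in milliseconds.)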
+ result.error = cudaEventSynchronize(events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventSynchronize() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Measure elapsed runtime + float runtime_ms = 0; + result.error = cudaEventElapsedTime(&runtime_ms, events[0], events[1]); + if (result.error != cudaSuccess) { + std::cerr << "cudaEventElapsed() failed: " << cudaGetErrorString(result.error) << std::endl; + return result; + } + + // Compute average runtime and GFLOPs. + result.runtime_ms = double(runtime_ms) / double(this->options.iterations); + result.gflops = this->options.gflops(result.runtime_ms / 1000.0); + + // + // Cleanup + // + + for (auto event : events) { + (void)cudaEventDestroy(event); + } + + // Optionally profile initialization + if (this->options.profile_initialization) { + // Warm up + gemm.initialize(args, workspace.get()); + + auto start_time = std::chrono::high_resolution_clock::now(); + for (int32_t i = 0; i < this->options.iterations; ++i) { + gemm.initialize(args, workspace.get()); + } + auto end_time = std::chrono::high_resolution_clock::now(); + + std::chrono::duration duration = end_time - start_time; + duration /= double(this->options.iterations); + result.initialization_time_ms = duration.count(); + } + + int64_t total_tiles = Gemm::group_tile_count(args); + std::cout << " " << total_tiles << " total threadblock tiles." << std::endl; + + std::cout << std::endl; + std::cout << " " << "Grouped Runtime: " << result.runtime_ms << " ms" << std::endl; + std::cout << " " << "Grouped GFLOPs: " << result.gflops << std::endl; + if (this->options.profile_initialization) { + std::cout << " " << "Init Runtime: " << result.initialization_time_ms << " ms" << std::endl; + } + + if (this->options.output_file.good()) { + this->options.output_file << this->options.output_tag << ",CUTLASS,grouped-" << sched_mode << "," + << this->options.problem_count << "," << result.runtime_ms << "," << result.gflops << std::endl; + } + + std::cout << "\nPassed\n"; + + return result; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +int main(int argc, char const **args) { + + cudaDeviceProp props; + + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (__CUDACC_VER_MAJOR__ < 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ < 4)) { + std::cerr << "This example requires CUDA 12.4 or greater." << std::endl; + return 0; + } + int device_idx; + cudaError_t result = cudaGetDevice(&device_idx); + cudaDeviceProp properties; + result = cudaGetDeviceProperties(&properties, device_idx); + + if (result != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() failed with error: " << cudaGetErrorString(result) << std::endl; + return 0; + } + + if (!(properties.major == 8 && properties.minor == 9)) { + std::cerr << "CUTLASS's Ada FP8 Gemm Grouped example requires a device of compute capability 89.\n" << std::endl; + return 0; + } + // + // Parse options + // + + Options options; + + options.parse(argc, args); + + if (options.help) { + options.print_usage(std::cout) << std::endl; + return 0; + } + + if (options.error) { + std::cerr << "Aborting execution." 
<< std::endl; + return -1; + } + + // + // Define the Grouped and Batched GEMM types + // + + using ElementA = cutlass::float_e4m3_t; + using ElementB = cutlass::float_e4m3_t; + using ElementOutput = cutlass::bfloat16_t; + using ElementAccumulator = float; + + using LayoutA = cutlass::layout::RowMajor; + using LayoutB = cutlass::layout::ColumnMajor; + using LayoutC = cutlass::layout::RowMajor; + + constexpr int ElementsPerAccessA = 128 / cutlass::sizeof_bits::value; + constexpr int ElementsPerAccessB = 128 / cutlass::sizeof_bits::value; + + // Define a grouped GEMM kernel with all template parameters set except + // for scheduling mode. This will be used as the template for all scheduling + // modes executed. + using GemmKernel = typename cutlass::gemm::kernel::DefaultGemmGroupedPerGroupScale< + ElementA, + LayoutA, + cutlass::ComplexTransform::kNone, + ElementsPerAccessA, + ElementB, + LayoutB, + cutlass::ComplexTransform::kNone, + ElementsPerAccessB, + ElementOutput, LayoutC, + ElementAccumulator, + cutlass::arch::OpClassTensorOp, + cutlass::arch::Sm89, + cutlass::gemm::GemmShape<64, 128, 64>, + cutlass::gemm::GemmShape<64, 32, 64>, + cutlass::gemm::GemmShape<16, 8, 32>, + cutlass::epilogue::thread::LinearCombination< + ElementOutput, 128 / cutlass::sizeof_bits::value, + ElementAccumulator, ElementAccumulator>, + // NOTE: Threadblock swizzling is currently not supported by CUTLASS's grouped kernels. + // This parameter is passed in at present to match the APIs of other kernels. The parameter + // is unused within the kernel. + cutlass::gemm::threadblock::GemmBatchedIdentityThreadblockSwizzle, + 4>::GemmKernel; + + using GemmGrouped = cutlass::gemm::device::GemmGrouped; + + // + // Profile it + // + + using GroupScheduleMode = cutlass::gemm::kernel::GroupScheduleMode; + for (GroupScheduleMode mode : options.scheduler_modes) { + Result result; + switch (mode) { + case GroupScheduleMode::kDeviceOnly: + { + TestbedGrouped runner(options); + result = runner.profile(); + break; + } + case GroupScheduleMode::kHostPrecompute: + { + TestbedGrouped runner(options); + result = runner.profile(); + break; + } + } + + if (result.error != cudaSuccess) { + return 1; + } + + // Override verbose flag to avoid printing duplicate information for each scheduling mode + options.verbose = false; + } + + return 0; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/examples/CMakeLists.txt b/examples/CMakeLists.txt index d5fdac141b..7e8d45227b 100644 --- a/examples/CMakeLists.txt +++ b/examples/CMakeLists.txt @@ -1,5 +1,5 @@ -# Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
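// The main() above fixes each global-memory access at 128 bits, so the elements-per-access
// constants follow directly from the element widths. A small worked sketch of that
// arithmetic for the types used in this example (cutlass::float_e4m3_t operands,
// cutlass::bfloat16_t output); the constant names here are illustrative only:

#include "cutlass/numeric_types.h"

constexpr int kAccessBits = 128;
constexpr int kElementsPerAccessA = kAccessBits / cutlass::sizeof_bits<cutlass::float_e4m3_t>::value; // 128 / 8  = 16
constexpr int kElementsPerAccessB = kAccessBits / cutlass::sizeof_bits<cutlass::float_e4m3_t>::value; // 128 / 8  = 16
constexpr int kElementsPerAccessC = kAccessBits / cutlass::sizeof_bits<cutlass::bfloat16_t>::value;   // 128 / 16 = 8

static_assert(kElementsPerAccessA == 16 && kElementsPerAccessC == 8, "128-bit vectorized accesses");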
# SPDX-License-Identifier: BSD-3-Clause # # Redistribution and use in source and binary forms, with or without @@ -44,7 +44,7 @@ function(cutlass_example_add_executable NAME) set(__DISABLE_TESTS OFF) endif() - cutlass_add_executable(${NAME} ${__UNPARSED_ARGUMENTS}) + cutlass_add_executable(${NAME} ${__UNPARSED_ARGUMENTS} BATCH_SOURCES OFF) add_dependencies(cutlass_examples ${NAME}) @@ -54,12 +54,14 @@ function(cutlass_example_add_executable NAME) CUTLASS cutlass_tools_util_includes $<$:nvidia::cublas> + cuda ) target_include_directories( ${NAME} PRIVATE ${CUTLASS_EXAMPLES_COMMON_SOURCE_DIR} + ${CUTLASS_EXAMPLES_UTILS_DIR} ) install( @@ -116,8 +118,35 @@ foreach(EXAMPLE 34_transposed_conv2d 35_gemm_softmax 36_gather_scatter_fusion + 37_gemm_layernorm_gemm_fusion + 38_syr2k_grouped + cute + 39_gemm_permute + 41_fused_multi_head_attention + 42_ampere_tensorop_group_conv + 43_ell_block_sparse_gemm + 45_dual_gemm + 46_depthwise_simt_conv2dfprop + 47_ampere_gemm_universal_streamk + 48_hopper_warp_specialized_gemm + 49_hopper_gemm_with_collective_builder + 50_hopper_gemm_with_epilogue_swizzle + 51_hopper_gett + 52_hopper_gather_scatter_fusion + 53_hopper_gemm_permute + 54_hopper_fp8_warp_specialized_gemm + 55_hopper_mixed_dtype_gemm + 56_hopper_ptr_array_batched_gemm + 57_hopper_grouped_gemm + 58_ada_fp8_gemm + 59_ampere_gather_scatter_conv + 61_hopper_gemm_with_topk_and_softmax + 62_hopper_sparse_gemm + 63_hopper_gemm_with_weight_prefetch + 64_ada_fp8_gemm_grouped ) add_subdirectory(${EXAMPLE}) endforeach() + diff --git a/examples/common/gather_tensor.hpp b/examples/common/gather_tensor.hpp new file mode 100644 index 0000000000..62616e00c7 --- /dev/null +++ b/examples/common/gather_tensor.hpp @@ -0,0 +1,215 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/layout.hpp" +#include "cute/tensor.hpp" +#include "cute/util/print.hpp" + +namespace example { + +using namespace cute; + +// Empty type used to disable gather/scatter for a GEMM argument +struct NoGather +{ + template + NoGather(Ts...) {}; +}; + +/// Function object that applies an index to its argument +template +struct IndexedGather +{ + CUTE_HOST_DEVICE constexpr + IndexedGather(Index const *indices = {}): indices_(indices) {} + + template + CUTE_HOST_DEVICE constexpr + Index + operator()(I i) const { return indices_[i]; } + + CUTE_HOST_DEVICE friend + void + print(IndexedGather const &s) { + cute::print("Indexed"); + } + + Index const *indices_; +}; + +/// Function object that applies a stride to its argument +/// Example: StridedFunc gathers every other row/column +template +struct StridedGather +{ + CUTE_HOST_DEVICE constexpr + StridedGather(Stride stride = {}): stride_(stride) {} + + template + CUTE_HOST_DEVICE constexpr + auto + operator()(I i) const { return i * stride_; } + + CUTE_HOST_DEVICE friend + void + print(StridedGather const &s) { + cute::print("Strided{"); + print(s.stride_); + cute::print("}"); + } + + Stride stride_; +}; + +/// Custom stride object that applies a function followed by a stride +template +struct CustomStride +{ + CUTE_HOST_DEVICE constexpr + CustomStride(Func const &func, Stride const &stride): func_(func), stride_(stride) {} + + template + CUTE_HOST_DEVICE constexpr friend + auto + operator*(I i, CustomStride const &s) { return s.func_(i) * s.stride_; } + + template + CUTE_HOST_DEVICE constexpr friend + auto + operator*(CustomStride const &s, I i) { return s.func_(i) * s.stride_; } + + CUTE_HOST_DEVICE friend + void + print(CustomStride const & s) { + cute::print("Custom{"); + print(s.func_); + cute::print(","); + print(s.stride_); + cute::print("}"); + } + + template + CUTE_HOST_DEVICE constexpr friend + auto + safe_div(CustomStride const &s, Div const &div) + { + return CustomStride(s.func_, safe_div(s.stride_, div)); + } + + // Circumvent the requirement on make_layout that shape and stride are integral + template + CUTE_HOST_DEVICE constexpr friend + auto + make_layout(Shape const &shape, CustomStride const &stride) + { + return Layout(shape, stride); + } + + Func func_; + Stride stride_; +}; + +template +CUTLASS_HOST_DEVICE +auto +make_custom_stride_layout(Stride const &stride, Func&& func) +{ + // Use a dummy shape and replace the first non-unit stride with a custom gather stride + auto idx = find_if(stride, [](auto x){ return not is_constant<1, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; + return make_layout(repeat_like(stride, _1{}), + replace(stride, CustomStride{static_cast(func), get(stride)})); +} + +/// Helper function to optionally create a gather tensor +template +CUTLASS_HOST_DEVICE +auto +make_gather_tensor(Iterator iter, Shape const &shape, Stride const &stride, Func &&func) +{ + if constexpr (not cutlass::platform::is_same, NoGather>::value) { + Layout matrix_layout = make_identity_layout(shape); + auto offset = as_arithmetic_tuple(repeat_like(shape, _0{})); + Layout gather_layout = make_custom_stride_layout(stride, static_cast(func)); + return make_tensor(iter, ComposedLayout{gather_layout, offset, matrix_layout}); + } else { + return make_tensor(iter, shape, stride); + } +} + +} // namespace example + +namespace cute +{ + +template +CUTE_HOST_DEVICE constexpr +auto 
+upcast(Shape const& shape, Stride const& stride) +{ + if constexpr (is_tuple::value) { + return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast(s,d); }); + } else if constexpr (is_scaled_basis::value) { + if constexpr (Stride::mode() == I) { + return make_layout(shape_div(shape, Int{}), shape_div(stride, Int{})); + } else { + return make_layout(shape, stride); + } + } else { + return upcast(shape, stride); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(ComposedLayout,Offset,Layout> const& layout) +{ + // Find index of the stride-1 mode - that is the only one that requires updating inner shape and offset + auto idx = find_if(layout.layout_a().stride(), [](auto x){ return is_constant<1, decltype(x)>{}; }); + constexpr int I = decltype(idx)::value; + + // Upcast the outer layout (works as expected) + auto outer = upcast(layout.layout_a()); + + // Upcast the accumulated offset along stride-1 mode + auto offset = as_arithmetic_tuple(replace(layout.offset(), upcast(get(layout.offset())))); + + // Upcast the inner layout's shape along stride-1 mode + auto inner = upcast(layout.layout_b().shape(), layout.layout_b().stride()); + + return composition(outer, offset, inner); +} + +} // namespace example diff --git a/examples/common/helper.h b/examples/common/helper.h index 2affd96c68..a7a81e7479 100644 --- a/examples/common/helper.h +++ b/examples/common/helper.h @@ -1,7 +1,41 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ #pragma once #include "cuda_runtime.h" +#include +/** + * Panic wrapper for unwinding CUTLASS errors + */ #define CUTLASS_CHECK(status) \ { \ cutlass::Status error = status; \ @@ -12,6 +46,10 @@ } \ } + +/** + * Panic wrapper for unwinding CUDA runtime errors + */ #define CUDA_CHECK(status) \ { \ cudaError_t error = status; \ @@ -21,3 +59,50 @@ exit(EXIT_FAILURE); \ } \ } + + +/** + * GPU timer for recording the elapsed time across kernel(s) launched in GPU stream + */ +struct GpuTimer +{ + cudaStream_t _stream_id; + cudaEvent_t _start; + cudaEvent_t _stop; + + /// Constructor + GpuTimer() : _stream_id(0) + { + CUDA_CHECK(cudaEventCreate(&_start)); + CUDA_CHECK(cudaEventCreate(&_stop)); + } + + /// Destructor + ~GpuTimer() + { + CUDA_CHECK(cudaEventDestroy(_start)); + CUDA_CHECK(cudaEventDestroy(_stop)); + } + + /// Start the timer for a given stream (defaults to the default stream) + void start(cudaStream_t stream_id = 0) + { + _stream_id = stream_id; + CUDA_CHECK(cudaEventRecord(_start, _stream_id)); + } + + /// Stop the timer + void stop() + { + CUDA_CHECK(cudaEventRecord(_stop, _stream_id)); + } + + /// Return the elapsed time (in milliseconds) + float elapsed_millis() + { + float elapsed = 0.0; + CUDA_CHECK(cudaEventSynchronize(_stop)); + CUDA_CHECK(cudaEventElapsedTime(&elapsed, _start, _stop)); + return elapsed; + } +}; diff --git a/examples/cute/CMakeLists.txt b/examples/cute/CMakeLists.txt new file mode 100644 index 0000000000..69aefd7c94 --- /dev/null +++ b/examples/cute/CMakeLists.txt @@ -0,0 +1,30 @@ + +# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
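// The GpuTimer added to examples/common/helper.h above wraps the same CUDA-event pattern
// behind start()/stop()/elapsed_millis(). A short usage sketch -- "my_kernel", its launch
// configuration, and "time_my_kernel" are hypothetical placeholders:

#include "helper.h"

__global__ void my_kernel() { /* ... device work ... */ }

float time_my_kernel(int iterations) {
  GpuTimer timer;
  timer.start();                                        // records the start event (default stream)
  for (int i = 0; i < iterations; ++i) {
    my_kernel<<<64, 256>>>();
  }
  timer.stop();                                         // records the stop event
  return timer.elapsed_millis() / float(iterations);    // synchronizes on the stop event, then averages
}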
+ +add_subdirectory(tutorial) diff --git a/examples/cute/tutorial/CMakeLists.txt b/examples/cute/tutorial/CMakeLists.txt new file mode 100644 index 0000000000..b427d9368c --- /dev/null +++ b/examples/cute/tutorial/CMakeLists.txt @@ -0,0 +1,60 @@ + +# Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +# SPDX-License-Identifier: BSD-3-Clause +# +# Redistribution and use in source and binary forms, with or without +# modification, are permitted provided that the following conditions are met: +# +# 1. Redistributions of source code must retain the above copyright notice, this +# list of conditions and the following disclaimer. +# +# 2. Redistributions in binary form must reproduce the above copyright notice, +# this list of conditions and the following disclaimer in the documentation +# and/or other materials provided with the distribution. +# +# 3. Neither the name of the copyright holder nor the names of its +# contributors may be used to endorse or promote products derived from +# this software without specific prior written permission. +# +# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +# AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +# IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +# DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +# FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +# DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +# SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +# CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +# OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +# OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + + +cutlass_example_add_executable( + sgemm_1 + sgemm_1.cu +) + +cutlass_example_add_executable( + sgemm_2 + sgemm_2.cu +) + +cutlass_example_add_executable( + sgemm_sm70 + sgemm_sm70.cu +) + +cutlass_example_add_executable( + sgemm_sm80 + sgemm_sm80.cu +) + +cutlass_example_add_executable( + tiled_copy + tiled_copy.cu +) + +cutlass_example_add_executable( + wgmma_sm90 + wgmma_sm90.cu +) + diff --git a/examples/cute/tutorial/sgemm_1.cu b/examples/cute/tutorial/sgemm_1.cu new file mode 100644 index 0000000000..e5bf9a9201 --- /dev/null +++ b/examples/cute/tutorial/sgemm_1.cu @@ -0,0 +1,469 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#include +#include +#include + +#include +#include + +#include + +#include "cutlass/util/print_error.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/helper_cuda.hpp" + +template +__global__ static +__launch_bounds__(decltype(size(CThreadLayout{}))::value) +void +gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler, + TA const* A, AStride dA, ASmemLayout sA_layout, AThreadLayout tA, + TB const* B, BStride dB, BSmemLayout sB_layout, BThreadLayout tB, + TC * C, CStride dC, CSmemLayout , CThreadLayout tC, + Alpha alpha, Beta beta) +{ + using namespace cute; + + // Preconditions + CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K) + CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K) + + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(is_static::value); + + CUTE_STATIC_ASSERT_V(size(tA) == size(tB)); // NumThreads + CUTE_STATIC_ASSERT_V(size(tC) == size(tA)); // NumThreads + + CUTE_STATIC_ASSERT_V(size<0>(cta_tiler) % size<0>(tA) == Int<0>{}); // BLK_M / THR_M + CUTE_STATIC_ASSERT_V(size<2>(cta_tiler) % size<1>(tA) == Int<0>{}); // BLK_K / THR_K + CUTE_STATIC_ASSERT_V(size<1>(cta_tiler) % size<0>(tB) == Int<0>{}); // BLK_N / THR_N + CUTE_STATIC_ASSERT_V(size<2>(cta_tiler) % size<1>(tB) == Int<0>{}); // BLK_K / THR_K + CUTE_STATIC_ASSERT_V(size<0>(cta_tiler) % size<0>(tC) == Int<0>{}); // BLK_M / THR_M + CUTE_STATIC_ASSERT_V(size<1>(cta_tiler) % size<1>(tC) == Int<0>{}); // BLK_N / THR_N + + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(is_static::value); + + CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M + CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M + CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N + CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N + CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K + CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler)); // BLK_K + + CUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA)); // dA strides for shape MK + CUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB)); // dB strides for shape NK + CUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC)); // dC strides for shape MN + + // + // Full and Tiled Tensors + // + + // Represent the full tensors + Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K) + Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K) + Tensor mC = make_tensor(make_gmem_ptr(C), 
select<0,1>(shape_MNK), dC); // (M,N) + + // Get the appropriate blocks for this thread block + auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k) + Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) + Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N) + + // Shared memory buffers + __shared__ TA smemA[cosize_v]; + __shared__ TB smemB[cosize_v]; + Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout); // (BLK_M,BLK_K) + Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout); // (BLK_N,BLK_K) + + // + // Partition the copying of A and B tiles across the threads + // + + // TUTORIAL: Example of simple raked partitioning of ThreadLayouts tA|tB over data A|B tiles + + Tensor tAgA = local_partition(gA, tA, threadIdx.x); // (THR_M,THR_K,k) + Tensor tAsA = local_partition(sA, tA, threadIdx.x); // (THR_M,THR_K) + + Tensor tBgB = local_partition(gB, tB, threadIdx.x); // (THR_N,THR_K,k) + Tensor tBsB = local_partition(sB, tB, threadIdx.x); // (THR_N,THR_K) + + CUTE_STATIC_ASSERT_V(size<0>(tAgA) == size<0>(tAsA)); // THR_M + CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA)); // THR_K + CUTE_STATIC_ASSERT_V(size<0>(tBgB) == size<0>(tBsB)); // THR_N + CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB)); // THR_K + + // + // Define A/B partitioning and C accumulators + // + + // TUTORIAL: Example of partitioning via projections of a ThreadLayout tC + + // Partition sA (M,K) by the rows of tC + Tensor tCsA = local_partition(sA, tC, threadIdx.x, Step<_1, X>{}); // (THR_M,BLK_K) + // Partition sB (N,K) by the cols of tC + Tensor tCsB = local_partition(sB, tC, threadIdx.x, Step< X,_1>{}); // (THR_N,BLK_K) + // Partition gC (M,N) by the tile of tC + Tensor tCgC = local_partition(gC, tC, threadIdx.x, Step<_1,_1>{}); // (THR_M,THR_N) + + // Allocate the accumulators -- same shape/layout as the partitioned data + Tensor tCrC = make_tensor_like(tCgC); // (THR_M,THR_N) + + CUTE_STATIC_ASSERT_V(size<0>(tCrC) == size<0>(tCgC)); // THR_M + CUTE_STATIC_ASSERT_V(size<0>(tCrC) == size<0>(tCsA)); // THR_M + CUTE_STATIC_ASSERT_V(size<1>(tCrC) == size<1>(tCgC)); // THR_N + CUTE_STATIC_ASSERT_V(size<1>(tCrC) == size<0>(tCsB)); // THR_N + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCsB)); // BLK_K + + // Clear the accumulators + clear(tCrC); + +#if 0 + if(thread0()) { + print(" mA : "); print( mA); print("\n"); + print(" gA : "); print( gA); print("\n"); + print(" sA : "); print( sA); print("\n"); + print("tAgA : "); print(tAgA); print("\n"); + print("tAsA : "); print(tAsA); print("\n"); + } +#endif + +#if 0 + if(thread0()) { + print(" mB : "); print( mB); print("\n"); + print(" gB : "); print( gB); print("\n"); + print(" sB : "); print( sB); print("\n"); + print("tBgB : "); print(tBgB); print("\n"); + print("tBsB : "); print(tBsB); print("\n"); + } +#endif + +#if 0 + if(thread0()) { + print(" mC : "); print( mC); print("\n"); + print(" gC : "); print( gC); print("\n"); + print("tCsA : "); print(tCsA); print("\n"); + print("tCsB : "); print(tCsB); print("\n"); + print("tCgC : "); print(tCgC); print("\n"); + print("tCrC : "); print(tCrC); print("\n"); + } +#endif + +#if 1 + + // TUTORIAL: Example of a simple mainloop that read tiles of data into shared memory, + // and then computes on those tiles. + // copy(.) operates on the global and shared memory via the tA|tB partitioning + // gemm(.) 
operates on the shared and register memory via the tC partitioning + + auto K_TILE_MAX = size<2>(tAgA); + + for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile) + { + // Copy gmem to smem with tA|tB thread-partitioned tensors + copy(tAgA(_,_,k_tile), tAsA); // A (THR_M,THR_K) -> (THR_M,THR_K) + copy(tBgB(_,_,k_tile), tBsB); // B (THR_N,THR_K) -> (THR_N,THR_K) + + // TUTORIAL: The above call to copy(tAgA(_,_,k_tile), tAsA) is equivalent to + // Tensor tAgAk = tAgA(_,_,k_tile); + // CUTE_UNROLL + // for (int i = 0; i < size(tAsA); ++i) { + // tAsA(i) = tAgAk(i); + // } + + cp_async_fence(); // Label the end of (potential) cp.async instructions + cp_async_wait<0>(); // Sync on all (potential) cp.async instructions + __syncthreads(); // Wait for all threads to write to smem + + // Compute gemm on tC thread-partitioned smem + gemm(tCsA, tCsB, tCrC); // (THR_M,THR_N) += (THR_M,BLK_K) * (THR_N,BLK_K) + + // TUTORIAL: The above call to gemm(tCsA, tCsB, tCrC) is equivalent to + // CUTE_UNROLL + // for (int k = 0; k < size<1>(tCsA); ++k) { + // CUTE_UNROLL + // for (int m = 0; m < size<0>(tCrC); ++m) { + // CUTE_UNROLL + // for (int n = 0; n < size<1>(tCrC); ++n) { + // tCrC(m,n) += tCsA(m,k) * tCsB(n,k); + // } + // } + // } + + __syncthreads(); // Wait for all threads to read from smem + } + +#endif + + // + // Epilogue + // + + axpby(alpha, tCrC, beta, tCgC); + + // TUTORIAL: The above call to axpby(alpha, tCrC, beta, tCgC) is equivalent to + // CUTE_UNROLL + // for (int i = 0; i < size(tCsA); ++i) { + // tCgC(i) = alpha * tCrC(i) + beta * tCgC(i); + // } +} + +// Setup params for an NT GEMM +// Use m-major smem sA, n-major smem sB, and mn-major threads tA|tB +template +void +gemm_nt(int m, int n, int k, + Alpha alpha, + TA const* A, int ldA, + TB const* B, int ldB, + Beta beta, + TC * C, int ldC, + cudaStream_t stream = 0) +{ + using namespace cute; + + // Define shapes (dynamic) + auto M = int(m); + auto N = int(n); + auto K = int(k); + auto prob_shape = make_shape(M, N, K); // (M, N, K) + + // Define NT strides (mixed) + auto dA = make_stride(Int<1>{}, ldA); // (dM, dK) + auto dB = make_stride(Int<1>{}, ldB); // (dN, dK) + auto dC = make_stride(Int<1>{}, ldC); // (dM, dN) + + // Define CTA tile sizes (static) + auto bM = Int<128>{}; + auto bN = Int<128>{}; + auto bK = Int< 8>{}; + auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K) + + // Define the smem layouts (static) + auto sA = make_layout(make_shape(bM, bK)); // (m,k) -> smem_idx; m-major + auto sB = make_layout(make_shape(bN, bK)); // (n,k) -> smem_idx; n-major + auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx; m-major + + // Define the thread layouts (static) + auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{})); // (m,k) -> thr_idx + auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{})); // (n,k) -> thr_idx + auto tC = make_layout(make_shape(Int<16>{}, Int<16>{})); // (m,n) -> thr_idx + + dim3 dimBlock(size(tC)); + dim3 dimGrid(size(ceil_div(M, bM)), + size(ceil_div(N, bN))); + gemm_device<<>> + (prob_shape, cta_tiler, + A, dA, sA, tA, + B, dB, sB, tB, + C, dC, sC, tC, + alpha, beta); +} + +// Setup params for a TN GEMM +// Use padded m-major smem sA, padded n-major smem sB, and k-major threads tA|tB +template +void +gemm_tn(int m, int n, int k, + Alpha alpha, + TA const* A, int ldA, + TB const* B, int ldB, + Beta beta, + TC * C, int ldC, + cudaStream_t stream = 0) +{ + using namespace cute; + + // Define shapes (dynamic) + auto M = int(m); + auto N = int(n); + auto K = int(k); + auto prob_shape = 
make_shape(M, N, K); // (M, N, K) + + // Define TN strides (mixed) + auto dA = make_stride(ldA, Int<1>{}); // (dM, dK) + auto dB = make_stride(ldB, Int<1>{}); // (dN, dK) + auto dC = make_stride(Int<1>{}, ldC); // (dM, dN) + + // Define CTA tile sizes (static) + auto bM = Int<128>{}; + auto bN = Int<128>{}; + auto bK = Int< 8>{}; + auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K) + + // Define the smem layouts (static) + auto sA = make_layout(make_shape(bM,bK), LayoutRight{}); // (m,k) -> smem_idx; k-major + auto sB = make_layout(make_shape(bN,bK), LayoutRight{}); // (n,k) -> smem_idx; k-major + auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx; m-major + + // Define the thread layouts (static) + auto tA = make_layout(make_shape(Int<32>{}, Int< 8>{}), LayoutRight{}); // (m,k) -> thr_idx; k-major + auto tB = make_layout(make_shape(Int<32>{}, Int< 8>{}), LayoutRight{}); // (n,k) -> thr_idx; k-major + auto tC = make_layout(make_shape(Int<16>{}, Int<16>{})); // (m,n) -> thr_idx; m-major + + dim3 dimBlock(size(tC)); + dim3 dimGrid(size(ceil_div(M, bM)), + size(ceil_div(N, bN))); + gemm_device<<>> + (prob_shape, cta_tiler, + A, dA, sA, tA, + B, dB, sB, tB, + C, dC, sC, tC, + alpha, beta); +} + +template +void +gemm(char transA, char transB, int m, int n, int k, + Alpha alpha, + TA const* A, int ldA, + TB const* B, int ldB, + Beta beta, + TC * C, int ldC, + cudaStream_t stream = 0) +{ + if (transA == 'N' && transB == 'T') { + return gemm_nt(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream); + } else + if (transA == 'T' && transB == 'N') { + return gemm_tn(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream); + } + assert(false && "Not implemented"); +} + + +int main(int argc, char** argv) +{ + int m = 5120; + if (argc >= 2) + sscanf(argv[1], "%d", &m); + + int n = 5120; + if (argc >= 3) + sscanf(argv[2], "%d", &n); + + int k = 4096; + if (argc >= 4) + sscanf(argv[3], "%d", &k); + + char transA = 'N'; + if (argc >= 5) + sscanf(argv[4], "%c", &transA); + + char transB = 'T'; + if (argc >= 6) + sscanf(argv[5], "%c", &transB); + + using TA = float; + using TB = float; + using TC = float; + using TI = float; + + TI alpha = 1.0; + TI beta = 0.0; + + std::cout << "M = " << m << std::endl; + std::cout << "N = " << n << std::endl; + std::cout << "K = " << k << std::endl; + std::cout << "C = A^" << transA << " B^" << transB << std::endl; + + cute::device_init(0); + + thrust::host_vector h_A(m*k); + thrust::host_vector h_B(n*k); + thrust::host_vector h_C(m*n); + + for (int j = 0; j < m*k; ++j) h_A[j] = static_cast( 2*(rand() / double(RAND_MAX)) - 1 ); + for (int j = 0; j < n*k; ++j) h_B[j] = static_cast( 2*(rand() / double(RAND_MAX)) - 1 ); + for (int j = 0; j < m*n; ++j) h_C[j] = static_cast(-1); + + thrust::device_vector d_A = h_A; + thrust::device_vector d_B = h_B; + thrust::device_vector d_C = h_C; + + double gflops = (2.0*m*n*k) * 1e-9; + + const int timing_iterations = 100; + GPU_Clock timer; + + int ldA = 0, ldB = 0, ldC = m; + + if (transA == 'N') { + ldA = m; + } else if (transA == 'T') { + ldA = k; + } else { + assert(false); + } + + if (transB == 'N') { + ldB = k; + } else if (transB == 'T') { + ldB = n; + } else { + assert(false); + } + // Run once + d_C = h_C; + gemm(transA, transB, m, n, k, + alpha, + d_A.data().get(), ldA, + d_B.data().get(), ldB, + beta, + d_C.data().get(), ldC); + CUTE_CHECK_LAST(); + thrust::host_vector cute_result = d_C; + + // Timing iterations + timer.start(); + for (int i = 0; i < timing_iterations; ++i) { + gemm(transA, transB, m, 
n, k, + alpha, + d_A.data().get(), ldA, + d_B.data().get(), ldB, + beta, + d_C.data().get(), ldC); + } + double cute_time = timer.seconds() / timing_iterations; + CUTE_CHECK_LAST(); + printf("CUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000); + return 0; +} diff --git a/examples/cute/tutorial/sgemm_2.cu b/examples/cute/tutorial/sgemm_2.cu new file mode 100644 index 0000000000..ee2b6b2e61 --- /dev/null +++ b/examples/cute/tutorial/sgemm_2.cu @@ -0,0 +1,523 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#include +#include +#include + +#include +#include + +#include + +#include "cutlass/util/print_error.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/helper_cuda.hpp" + +template +__global__ static +__launch_bounds__(decltype(size(TiledMma{}))::value) +void +gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler, + TA const* A, AStride dA, ASmemLayout sA_layout, TiledCopyA copy_a, + TB const* B, BStride dB, BSmemLayout sB_layout, TiledCopyB copy_b, + TC * C, CStride dC, CSmemLayout , TiledMma mma, + Alpha alpha, Beta beta) +{ + using namespace cute; + + // Preconditions + CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K) + CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K) + + CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma)); // NumThreads + CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma)); // NumThreads + + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(is_static::value); + + CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M + CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M + CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N + CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N + CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K + CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler)); // BLK_K + + CUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA)); // dA strides for shape MK + CUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB)); // dB strides for shape NK + CUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC)); // dC strides for shape MN + + // + // Full and Tiled Tensors + // + + // Represent the full tensors + Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K) + Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K) + Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N) + + // Get the appropriate blocks for this thread block + auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k) + Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) + Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N) + + // Shared memory buffers + __shared__ TA smemA[cosize_v]; + __shared__ TB smemB[cosize_v]; + Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout); // (BLK_M,BLK_K) + Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout); // (BLK_N,BLK_K) + + // + // Partition the copying of A and B tiles across the threads + // + + // TUTORIAL: Example of partitioning via a TiledCopy + + ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x); + Tensor tAgA = thr_copy_a.partition_S(gA); // (CPY,CPY_M,CPY_K,k) + Tensor tAsA = thr_copy_a.partition_D(sA); // (CPY,CPY_M,CPY_K) + // Allocate registers same shape/layout as partitioned data + Tensor tArA = make_fragment_like(tAsA); // (CPY,CPY_M,CPY_K) + + ThrCopy thr_copy_b = copy_b.get_slice(threadIdx.x); + Tensor tBgB = thr_copy_b.partition_S(gB); // (CPY,CPY_N,CPY_K,k) + Tensor tBsB = thr_copy_b.partition_D(sB); // (CPY,CPY_N,CPY_K) + // Allocate registers same shape/layout as partitioned data + Tensor tBrB = make_fragment_like(tBsB); // (CPY,CPY_N,CPY_K) + + 
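// The partitioning above routes every gmem<->smem transfer through a TiledCopy, and the
// per-thread views carry (CPY, CPY_M, CPY_K) modes. A host-side sketch, assuming CuTe's
// UniversalCopy atom over 128-bit uint128_t vectors (the same style the NT setup below
// constructs), that prints one thread's partition of a 128x8 m-major tile; the function
// name is illustrative only:

#include <cute/tensor.hpp>

void inspect_tiled_copy_partition() {
  using namespace cute;
  float* gmem = nullptr;   // layout inspection only -- nothing is dereferenced
  Tensor gA = make_tensor(make_gmem_ptr(gmem),
                          make_layout(make_shape(Int<128>{}, Int<8>{})));   // (BLK_M,BLK_K), m-major
  TiledCopy copyA = make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, float>{},
                                    Layout<Shape<_32,_8>>{},    // 32x8 threads, m-major
                                    Layout<Shape< _4,_1>>{});   // 4x1 values per thread
  ThrCopy thr_copy = copyA.get_slice(0);        // slice owned by thread 0
  Tensor tAgA = thr_copy.partition_S(gA);       // (CPY,CPY_M,CPY_K) view of the source tile
  print(tAgA); print("\n");                     // shows the per-thread shape and strides
}

// The static asserts that follow check exactly these CPY_M/CPY_N/CPY_K extents against the
// shared-memory and register fragments.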
CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA)); // CPY_M + CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tArA)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tAsA)); // CPY_K + CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tArA)); // CPY_K + CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB)); // CPY_N + CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBrB)); // CPY_N + CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBsB)); // CPY_K + CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBrB)); // CPY_K + + // Copy gmem to rmem for k_tile=0 + copy(copy_a, tAgA(_,_,_,0), tArA); + copy(copy_b, tBgB(_,_,_,0), tBrB); + // + // Define A/B partitioning and C accumulators + // + + // TUTORIAL: Example of partitioning via a TiledMMA + + ThrMMA thr_mma = mma.get_slice(threadIdx.x); + Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K) + Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K) + Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N) + + // Allocate the accumulators -- same size as the projected data + Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N) + + CUTE_STATIC_ASSERT_V( shape(tCrC) == shape(tCgC)); // (MMA,MMA_M,MMA_N) + CUTE_STATIC_ASSERT_V(size<1>(tCgC) == size<1>(tCsA)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(tCgC) == size<1>(tCsB)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // MMA_K + + // Clear the accumulators + clear(tCrC); + +#if 0 + if(thread0()) { + print(" mA : "); print( mA); print("\n"); + print(" gA : "); print( gA); print("\n"); + print(" sA : "); print( sA); print("\n"); + print("tAgA : "); print(tAgA); print("\n"); + print("tAsA : "); print(tAsA); print("\n"); + print("tArA : "); print(tArA); print("\n"); + } +#endif + +#if 0 + if(thread0()) { + print(" mB : "); print( mB); print("\n"); + print(" gB : "); print( gB); print("\n"); + print(" sB : "); print( sB); print("\n"); + print("tBgB : "); print(tBgB); print("\n"); + print("tBsB : "); print(tBsB); print("\n"); + print("tArA : "); print(tArA); print("\n"); + } +#endif + +#if 0 + if(thread0()) { + print(" mC : "); print( mC); print("\n"); + print(" gC : "); print( gC); print("\n"); + print("tCsA : "); print(tCsA); print("\n"); + print("tCsB : "); print(tCsB); print("\n"); + print("tCgC : "); print(tCgC); print("\n"); + print("tCrC : "); print(tCrC); print("\n"); + } +#endif + +#if 1 + + // TUTORIAL: Example of an inner loop that pipelines compute with reads + // from global memory by staging through register and shared memory. + // Data is read from global to registers, then to shared via the TiledCopy partitions + // gemm(.) operates on the shared memory directly via the TiledMMA partitions + + auto K_TILE_MAX = size<3>(tAgA); + + for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile) + { + // Copy rmem to smem with tA|tB thread-partitioned tensors + __syncthreads(); // Wait for all threads to consume smem + copy(tArA, tAsA); + copy(tBrB, tBsB); + __syncthreads(); // Wait for all threads to consume smem + + // Copy gmem to rmem for k_tile+1 with tA|tB thread-partitioned tensors + int k_tile_next = (k_tile + 1 < K_TILE_MAX) ? 
k_tile + 1 : k_tile; + copy(copy_a, tAgA(_,_,_,k_tile_next), tArA); + copy(copy_b, tBgB(_,_,_,k_tile_next), tBrB); + // TUTORIAL: The above call to copy(copy_a, tAgA(_,_,_,k_tile_next), tArA) is equivalent to + // CUTE_UNROLL + // for (int k = 0; k < size<1>(tCsA); ++k) { + // CUTE_UNROLL + // for (int m = 0; m < size<0>(tCrC); ++m) { + // copy_a.call(tAgA(_,m,k), tArA(_,m,k); + // } + // } + + // Compute gemm on mma-partitioned smem + gemm(mma, tCsA, tCsB, tCrC); + // TUTORIAL: The above call to gemm(tCsA, tCsB, tCrC) is equivalent to + // CUTE_UNROLL + // for (int k = 0; k < size<1>(tCsA); ++k) { + // CUTE_UNROLL + // for (int m = 0; m < size<0>(tCrC); ++m) { + // CUTE_UNROLL + // for (int n = 0; n < size<1>(tCrC); ++n) { + // mma.call(tCsA(_,m,k), tCsB(_,n,k), tCrC(_,m,n); + // } + // } + // } + } + +#endif + + // + // Epilogue + // + + axpby(alpha, tCrC, beta, tCgC); +} + +// Setup params for a NT GEMM +template +void +gemm_nt(int m, int n, int k, + Alpha alpha, + TA const* A, int ldA, + TB const* B, int ldB, + Beta beta, + TC * C, int ldC, + cudaStream_t stream = 0) +{ + using namespace cute; + + // Define shapes (dynamic) + auto M = int(m); + auto N = int(n); + auto K = int(k); + auto prob_shape = make_shape(M, N, K); // (M, N, K) + + // Define NT strides (mixed) + auto dA = make_stride(Int<1>{}, ldA); // (dM, dK) + auto dB = make_stride(Int<1>{}, ldB); // (dN, dK) + auto dC = make_stride(Int<1>{}, ldC); // (dM, dN) + + // Define CTA tile sizes (static) + auto bM = Int<128>{}; + auto bN = Int<128>{}; + auto bK = Int< 8>{}; + auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K) + + // Define the smem layouts (static) + auto sA = make_layout(make_shape(bM, bK)); // (m,k) -> smem_idx; m-major + auto sB = make_layout(make_shape(bN, bK)); // (n,k) -> smem_idx; n-major + auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx; m-major + + // Define the thread layouts (static) + + // TUTORIAL: Construct TiledCopy with a particular Copy_Atom to use and + // define the partitioning pattern to apply. + // Each thread will (try to) copy 4x1 elements of type TA using 128-bit copy. + // Use 32x8 of these threads. + + TiledCopy copyA = make_tiled_copy(Copy_Atom, TA>{}, + Layout>{}, // Thr layout 32x8 m-major + Layout>{}); // Val layout 4x1 m-major + TiledCopy copyB = make_tiled_copy(Copy_Atom, TB>{}, + Layout>{}, // Thr layout 32x8 n-major + Layout>{}); // Val layout 4x1 n-major + + // TUTORIAL: Construct TiledMMA with a particular MMA_Atom to use and + // define the partitioning pattern to apply. + // Use a 1x1x1 FMA on the types TC += TA * TB. Each atom requires a single thread. + // Reproduce that atom 16x16x1 times (m-major) across threads so that we use 256 threads. 
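// The TiledMMA referenced in the comment above replicates a one-thread FMA atom across a
// 16x16x1 m-major grid of threads, so size(mmaC) = 16*16*1 = 256 and dimBlock picks up
// 256 threads per CTA. A minimal sketch of that construction for float operands, assuming
// CuTe's UniversalFMA atom; the function name is illustrative only:

#include <cute/tensor.hpp>

void build_fma_tiled_mma() {
  using namespace cute;
  TiledMMA mma = make_tiled_mma(UniversalFMA<float, float, float>{},   // TC += TA * TB, one thread per atom
                                Layout<Shape<_16,_16,_1>>{});          // 16x16x1 atoms -> 256 threads
  print(size(mma)); print("\n");                                       // _256
}

// The construction that follows in gemm_nt is this same call with the TA/TB/TC template
// parameters substituted.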
+ + TiledMMA mmaC = make_tiled_mma(UniversalFMA{}, + Layout>{}); // 16x16x1 UniversalFMA + +#if 0 + print(copyA); + print(copyB); + print(mmaC); +#endif + +#if 0 + print_latex(copyA); + print_latex(copyB); + print_latex(mmaC); +#endif + + dim3 dimBlock(size(mmaC)); + dim3 dimGrid(size(ceil_div(M, bM)), + size(ceil_div(N, bN))); + gemm_device<<>> + (prob_shape, cta_tiler, + A, dA, sA, copyA, + B, dB, sB, copyB, + C, dC, sC, mmaC, + alpha, beta); +} + +// Setup params for a TN GEMM +template +void +gemm_tn(int m, int n, int k, + Alpha alpha, + TA const* A, int ldA, + TB const* B, int ldB, + Beta beta, + TC * C, int ldC, + cudaStream_t stream = 0) +{ + using namespace cute; + + // Define shapes (dynamic) + auto M = int(m); + auto N = int(n); + auto K = int(k); + auto prob_shape = make_shape(M, N, K); // (M, N, K) + + // Define TN strides (mixed) + auto dA = make_stride(ldA, Int<1>{}); // (dM, dK) + auto dB = make_stride(ldB, Int<1>{}); // (dN, dK) + auto dC = make_stride(Int<1>{}, ldC); // (dM, dN) + + // Define CTA tile sizes (static) + auto bM = Int<128>{}; + auto bN = Int<128>{}; + auto bK = Int< 8>{}; + auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K) + + // Define the smem layouts (static) + auto sA = make_layout(make_shape ( bM, bK), + make_stride(Int<1>{}, bM+Int<1>{})); // (m,k) -> smem_idx; padded m-major + auto sB = make_layout(make_shape ( bN, bK), + make_stride(Int<1>{}, bN+Int<1>{})); // (n,k) -> smem_idx; padded n-major + auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx + + // TUTORIAL: Construct TiledCopy to define the Copy_Atom to use and the + // partitioning pattern to apply. + // Each thread will copy 1x1 elements of type TA. + // Use 32x8 of these threads arranged in k-major. + + TiledCopy copyA = make_tiled_copy(Copy_Atom, TA>{}, + Layout,Stride<_8,_1>>{}, // Thr layout 32x8 k-major + Layout>{}); // Val layout 1x1 + TiledCopy copyB = make_tiled_copy(Copy_Atom, TB>{}, + Layout,Stride<_8,_1>>{}, // Thr layout 32x8 k-major + Layout>{}); // Val layout 1x1 + + // TUTORIAL: Construct TiledMMA to define the MMA_Atom to use and the + // partitioning pattern to apply. + // Use a 1x1x1 FMA on the types TC += TA * TB. Each atom requires a single thread. + // Reproduce that atom 16x16x1 times (m-major) across threads so that we use 256 threads. 
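// The sA/sB shared-memory layouts defined above for this TN path pad the leading dimension
// to bM+1 / bN+1 so that consecutive k-columns start at staggered addresses, which helps
// avoid shared-memory bank conflicts with the k-major copies. A small sketch of the
// resulting layout for bM = 128, bK = 8, assuming only CuTe; the function name is
// illustrative only:

#include <cute/tensor.hpp>

void inspect_padded_smem_layout() {
  using namespace cute;
  auto sA_padded = make_layout(make_shape (Int<128>{}, Int<8>{}),
                               make_stride(Int<  1>{}, Int<129>{}));   // m-major with leading dimension 129
  print(sA_padded);         print("\n");   // (_128,_8):(_1,_129)
  print(cosize(sA_padded)); print("\n");   // _1031: one padding element between consecutive k-columns (7 extra over the unpadded 1024)
}

// The TiledMMA construction that follows mirrors the NT path sketched earlier: a 16x16x1
// grid of UniversalFMA atoms, i.e. 256 threads per CTA.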
+ + TiledMMA mmaC = make_tiled_mma(UniversalFMA{}, + Layout>{}); // 16x16x1 TiledMMA + +#if 0 + print(copyA); + print(copyB); + print(mmaC); +#endif + +#if 0 + print_latex(copyA); + print_latex(copyB); + print_latex(mmaC); +#endif + + dim3 dimBlock(size(mmaC)); + dim3 dimGrid(size(ceil_div(M, bM)), + size(ceil_div(N, bN))); + gemm_device<<>> + (prob_shape, cta_tiler, + A, dA, sA, copyA, + B, dB, sB, copyB, + C, dC, sC, mmaC, + alpha, beta); +} + +template +void +gemm(char transA, char transB, int m, int n, int k, + Alpha alpha, + TA const* A, int ldA, + TB const* B, int ldB, + Beta beta, + TC * C, int ldC, + cudaStream_t stream = 0) +{ + if (transA == 'N' && transB == 'T') { + return gemm_nt(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream); + } else + if (transA == 'T' && transB == 'N') { + return gemm_tn(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream); + } + assert(false && "Not implemented"); +} + + +int main(int argc, char** argv) +{ + int m = 5120; + if (argc >= 2) + sscanf(argv[1], "%d", &m); + + int n = 5120; + if (argc >= 3) + sscanf(argv[2], "%d", &n); + + int k = 4096; + if (argc >= 4) + sscanf(argv[3], "%d", &k); + + char transA = 'N'; + if (argc >= 5) + sscanf(argv[4], "%c", &transA); + + char transB = 'T'; + if (argc >= 6) + sscanf(argv[5], "%c", &transB); + + using TA = float; + using TB = float; + using TC = float; + using TI = float; + + TI alpha = 1.0; + TI beta = 0.0; + + std::cout << "M = " << m << std::endl; + std::cout << "N = " << n << std::endl; + std::cout << "K = " << k << std::endl; + std::cout << "C = A^" << transA << " B^" << transB << std::endl; + + cute::device_init(0); + + thrust::host_vector h_A(m*k); + thrust::host_vector h_B(n*k); + thrust::host_vector h_C(m*n); + + for (int j = 0; j < m*k; ++j) h_A[j] = static_cast( 2*(rand() / double(RAND_MAX)) - 1 ); + for (int j = 0; j < n*k; ++j) h_B[j] = static_cast( 2*(rand() / double(RAND_MAX)) - 1 ); + for (int j = 0; j < m*n; ++j) h_C[j] = static_cast(-1); + + thrust::device_vector d_A = h_A; + thrust::device_vector d_B = h_B; + thrust::device_vector d_C = h_C; + + double gflops = (2.0*m*n*k) * 1e-9; + + const int timing_iterations = 100; + GPU_Clock timer; + + int ldA = 0, ldB = 0, ldC = m; + + if (transA == 'N') { + ldA = m; + } else if (transA == 'T') { + ldA = k; + } else { + assert(false); + } + + if (transB == 'N') { + ldB = k; + } else if (transB == 'T') { + ldB = n; + } else { + assert(false); + } + + // Run once + d_C = h_C; + gemm(transA, transB, m, n, k, + alpha, + d_A.data().get(), ldA, + d_B.data().get(), ldB, + beta, + d_C.data().get(), ldC); + CUTE_CHECK_LAST(); + thrust::host_vector cute_result = d_C; + + // Timing iterations + timer.start(); + for (int i = 0; i < timing_iterations; ++i) { + gemm(transA, transB, m, n, k, + alpha, + d_A.data().get(), ldA, + d_B.data().get(), ldB, + beta, + d_C.data().get(), ldC); + } + double cute_time = timer.seconds() / timing_iterations; + CUTE_CHECK_LAST(); + printf("CUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000); + + return 0; +} diff --git a/examples/cute/tutorial/sgemm_sm70.cu b/examples/cute/tutorial/sgemm_sm70.cu new file mode 100644 index 0000000000..ef6284cf00 --- /dev/null +++ b/examples/cute/tutorial/sgemm_sm70.cu @@ -0,0 +1,526 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#include +#include +#include + +#include +#include + +#include + +#include "cutlass/util/print_error.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/helper_cuda.hpp" + +template +__global__ static +__launch_bounds__(decltype(size(TiledMma{}))::value) +void +gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler, + TA const* A, AStride dA, ASmemLayout sA_layout, TiledCopyA copy_a, + TB const* B, BStride dB, BSmemLayout sB_layout, TiledCopyB copy_b, + TC * C, CStride dC, CSmemLayout , TiledMma mma, + Alpha alpha, Beta beta) +{ + using namespace cute; + + // Preconditions + CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K) + CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K) + + CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma)); // NumThreads + CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma)); // NumThreads + + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(is_static::value); + + CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M + CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M + CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N + CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N + CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K + CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler)); // BLK_K + + CUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA)); // dA strides for shape MK + CUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB)); // dB strides for shape NK + CUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC)); // dC strides for shape MN + + // + // Full and Tiled Tensors + // + + // Represent the full tensors + Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // 
(M,K) + Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K) + Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N) + + // Get the appropriate blocks for this thread block + auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k) + Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) + Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N) + + // Shared memory buffers + __shared__ TA smemA[cosize_v]; + __shared__ TB smemB[cosize_v]; + Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout); // (BLK_M,BLK_K) + Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout); // (BLK_N,BLK_K) + + // + // Partition the copying of A and B tiles across the threads + // + + // TUTORIAL: Example of partitioning via a TiledCopy + + ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x); + Tensor tAgA = thr_copy_a.partition_S(gA); // (CPY,CPY_M,CPY_K,k) + Tensor tAsA = thr_copy_a.partition_D(sA); // (CPY,CPY_M,CPY_K) + Tensor tArA = make_fragment_like(tAsA); // (CPY,CPY_M,CPY_K) + + ThrCopy thr_copy_b = copy_b.get_slice(threadIdx.x); + Tensor tBgB = thr_copy_b.partition_S(gB); // (CPY,CPY_N,CPY_K,k) + Tensor tBsB = thr_copy_b.partition_D(sB); // (CPY,CPY_N,CPY_K) + Tensor tBrB = make_fragment_like(tBsB); // (CPY,CPY_N,CPY_K) + + CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA)); // CPY_M + CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tArA)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tAsA)); // CPY_K + CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tArA)); // CPY_K + CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB)); // CPY_N + CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBrB)); // CPY_N + CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBsB)); // CPY_K + CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBrB)); // CPY_K + + // Copy gmem to rmem for k_tile=0 + copy(copy_a, tAgA(_,_,_,0), tArA); + copy(copy_b, tBgB(_,_,_,0), tBrB); + // + // Define A/B partitioning and C accumulators + // + + // TUTORIAL: Example of partitioning via a TiledMMA + + ThrMMA thr_mma = mma.get_slice(threadIdx.x); + Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K) + Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K) + Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N) + + // Allocate registers for pipelining + Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K) + Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K) + // Allocate the accumulators -- same size as the projected data + Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N) + + CUTE_STATIC_ASSERT_V( shape(tCrA) == shape(tCsA)); // (MMA,MMA_M,MMA_K) + CUTE_STATIC_ASSERT_V( shape(tCrB) == shape(tCsB)); // (MMA,MMA_N,MMA_K) + CUTE_STATIC_ASSERT_V( shape(tCrC) == shape(tCgC)); // (MMA,MMA_M,MMA_N) + CUTE_STATIC_ASSERT_V(size<1>(tCgC) == size<1>(tCsA)); // MMA_M + CUTE_STATIC_ASSERT_V(size<2>(tCgC) == size<1>(tCsB)); // MMA_N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // MMA_K + + // Clear the accumulators + clear(tCrC); + +#if 0 + if(thread0()) { + print(" mA : "); print( mA); print("\n"); + print(" gA : "); print( gA); print("\n"); + print(" sA : "); print( sA); print("\n"); + print("tAgA : "); print(tAgA); print("\n"); + print("tAsA : "); print(tAsA); print("\n"); + print("tArA : "); print(tArA); print("\n"); + } +#endif + +#if 0 + if(thread0()) { + print(" mB : 
"); print( mB); print("\n"); + print(" gB : "); print( gB); print("\n"); + print(" sB : "); print( sB); print("\n"); + print("tBgB : "); print(tBgB); print("\n"); + print("tBsB : "); print(tBsB); print("\n"); + print("tArA : "); print(tArA); print("\n"); + } +#endif + +#if 0 + if(thread0()) { + print(" mC : "); print( mC); print("\n"); + print(" gC : "); print( gC); print("\n"); + print("tCsA : "); print(tCsA); print("\n"); + print("tCsB : "); print(tCsB); print("\n"); + print("tCgC : "); print(tCgC); print("\n"); + print("tCrC : "); print(tCrC); print("\n"); + } +#endif + +#if 1 + + // Copy rmem to smem + copy(tArA, tAsA); + copy(tBrB, tBsB); + __syncthreads(); + + // + // PIPELINED MAIN LOOP + // TUTORIAL: Example of a gemm loop that pipelines shared memory AND register memory + // Data is read from global to registers, then to shared via the tA|tB partitions + // Data is then copied from shared to registers in multiple waves via the tC partitions + // and gemm(.) operates on the current register wave + // + + // Load A, B shmem->regs for k_block=0 + copy(tCsA(_,_,0), tCrA(_,_,0)); + copy(tCsB(_,_,0), tCrB(_,_,0)); + auto K_TILE_MAX = size<3>(tAgA); + auto K_BLOCK_MAX = size<2>(tCrA); + + CUTE_NO_UNROLL + for (int k_tile = 0; k_tile < K_TILE_MAX; ++k_tile) + { + // Pipeline the k-mode of the block registers + CUTE_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) + { + if (k_block == K_BLOCK_MAX - 1) + { + // Copy rmem to smem + __syncthreads(); + copy(tArA, tAsA); + copy(tBrB, tBsB); + __syncthreads(); + } + + // Copy smem to rmem for k_block+1 + int k_block_next = (k_block + 1) % K_BLOCK_MAX; + copy(tCsA(_,_,k_block_next), tCrA(_,_,k_block_next)); + copy(tCsB(_,_,k_block_next), tCrB(_,_,k_block_next)); + if (k_block == 0) + { + // Copy gmem to rmem for k_tile+1 + int k_tile_next = (k_tile + 1 < K_TILE_MAX) ? 
k_tile + 1 : k_tile; + copy(copy_a, tAgA(_,_,_,k_tile_next), tArA); + copy(copy_b, tBgB(_,_,_,k_tile_next), tBrB); + } + // Thread-level register gemm for k_block + gemm(mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + } // k_block + } // k_tile + +#endif + + // + // Epilogue + // + + axpby(alpha, tCrC, beta, tCgC); +} + +// Setup params for a NT GEMM +template +void +gemm_nt(int m, int n, int k, + Alpha alpha, + TA const* A, int ldA, + TB const* B, int ldB, + Beta beta, + TC * C, int ldC, + cudaStream_t stream = 0) +{ + using namespace cute; + + // Define shapes (dynamic) + auto M = int(m); + auto N = int(n); + auto K = int(k); + auto prob_shape = make_shape(M, N, K); // (M, N, K) + + // Define NT strides (mixed) + auto dA = make_stride(Int<1>{}, ldA); // (dM, dK) + auto dB = make_stride(Int<1>{}, ldB); // (dN, dK) + auto dC = make_stride(Int<1>{}, ldC); // (dM, dN) + + // Define CTA tile sizes (static) + auto bM = Int<128>{}; + auto bN = Int<128>{}; + auto bK = Int< 8>{}; + auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K) + + // Define the smem layouts (static) + auto sA = make_layout(make_shape(bM, bK)); // (m,k) -> smem_idx; m-major + auto sB = make_layout(make_shape(bN, bK)); // (n,k) -> smem_idx; n-major + auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx; m-major + + // Define the thread layouts (static) + TiledCopy copyA = make_tiled_copy(Copy_Atom, TA>{}, + Layout>{}, // Thr layout 32x8 m-major + Layout>{}); // Val layout 4x1 m-major + TiledCopy copyB = make_tiled_copy(Copy_Atom, TB>{}, + Layout>{}, // Thr layout 32x8 n-major + Layout>{}); // Val layout 4x1 n-major + + TiledMMA mmaC = make_tiled_mma(UniversalFMA{}, + Layout>{}); // 16x16x1 TiledMMA + +#if 0 + print(copyA); + print(copyB); + print(mmaC); +#endif + +#if 0 + print_latex(copyA); + print_latex(copyB); + print_latex(mmaC); +#endif + + dim3 dimBlock(size(mmaC)); + dim3 dimGrid(size(ceil_div(M, bM)), + size(ceil_div(N, bN))); + gemm_device<<>> + (prob_shape, cta_tiler, + A, dA, sA, copyA, + B, dB, sB, copyB, + C, dC, sC, mmaC, + alpha, beta); +} + +// Setup params for a TN GEMM +template +void +gemm_tn(int m, int n, int k, + Alpha alpha, + TA const* A, int ldA, + TB const* B, int ldB, + Beta beta, + TC * C, int ldC, + cudaStream_t stream = 0) +{ + using namespace cute; + + // Define shapes (dynamic) + auto M = int(m); + auto N = int(n); + auto K = int(k); + auto prob_shape = make_shape(M, N, K); // (M, N, K) + + // Define TN strides (mixed) + auto dA = make_stride(ldA, Int<1>{}); // (dM, dK) + auto dB = make_stride(ldB, Int<1>{}); // (dN, dK) + auto dC = make_stride(Int<1>{}, ldC); // (dM, dN) + + // Define CTA tile sizes (static) + auto bM = Int<128>{}; + auto bN = Int<128>{}; + auto bK = Int< 8>{}; + auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K) + + // Define the smem layouts (static) + auto sA = make_layout(make_shape ( bM, bK), + make_stride(Int<1>{}, bM+Int<1>{})); // (m,k) -> smem_idx; padded m-major + auto sB = make_layout(make_shape ( bN, bK), + make_stride(Int<1>{}, bN+Int<1>{})); // (n,k) -> smem_idx; padded n-major + auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx + + // Define the thread layouts (static) + + TiledCopy copyA = make_tiled_copy(Copy_Atom, TA>{}, + Layout,Stride<_8,_1>>{}, // Thr layout 32x8 k-major + Layout>{}); // Val layout 1x1 + TiledCopy copyB = make_tiled_copy(Copy_Atom, TB>{}, + Layout,Stride<_8,_1>>{}, // Thr layout 32x8 k-major + Layout>{}); // Val layout 1x1 + + TiledMMA mmaC = make_tiled_mma(UniversalFMA{}, + 
Layout>{}); // 16x16x1 TiledMMA + +#if 0 + print(copyA); + print(copyB); + print(mmaC); +#endif + +#if 0 + print_latex(copyA); + print_latex(copyB); + print_latex(mmaC); +#endif + + dim3 dimBlock(size(mmaC)); + dim3 dimGrid(size(ceil_div(M, bM)), + size(ceil_div(N, bN))); + gemm_device<<>> + (prob_shape, cta_tiler, + A, dA, sA, copyA, + B, dB, sB, copyB, + C, dC, sC, mmaC, + alpha, beta); +} + +template +void +gemm(char transA, char transB, int m, int n, int k, + Alpha alpha, + TA const* A, int ldA, + TB const* B, int ldB, + Beta beta, + TC * C, int ldC, + cudaStream_t stream = 0) +{ + if (transA == 'N' && transB == 'T') { + return gemm_nt(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream); + } else + if (transA == 'T' && transB == 'N') { + return gemm_tn(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream); + } + assert(false && "Not implemented"); +} + + +int main(int argc, char** argv) +{ + cudaDeviceProp props; + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (props.major < 7) { + std::cout << "This example requires an Volta GPU or newer (CC >= 70)" << std::endl; + // Return 0 so tests pass if run on unsupported architectures or CUDA Toolkits. + return 0; + } + + int m = 5120; + if (argc >= 2) + sscanf(argv[1], "%d", &m); + + int n = 5120; + if (argc >= 3) + sscanf(argv[2], "%d", &n); + + int k = 4096; + if (argc >= 4) + sscanf(argv[3], "%d", &k); + + char transA = 'N'; + if (argc >= 5) + sscanf(argv[4], "%c", &transA); + + char transB = 'T'; + if (argc >= 6) + sscanf(argv[5], "%c", &transB); + + using TA = float; + using TB = float; + using TC = float; + using TI = float; + + TI alpha = 1.0; + TI beta = 0.0; + + std::cout << "M = " << m << std::endl; + std::cout << "N = " << n << std::endl; + std::cout << "K = " << k << std::endl; + std::cout << "C = A^" << transA << " B^" << transB << std::endl; + + thrust::host_vector h_A(m*k); + thrust::host_vector h_B(n*k); + thrust::host_vector h_C(m*n); + + for (int j = 0; j < m*k; ++j) h_A[j] = static_cast( 2*(rand() / double(RAND_MAX)) - 1 ); + for (int j = 0; j < n*k; ++j) h_B[j] = static_cast( 2*(rand() / double(RAND_MAX)) - 1 ); + for (int j = 0; j < m*n; ++j) h_C[j] = static_cast(-1); + + thrust::device_vector d_A = h_A; + thrust::device_vector d_B = h_B; + thrust::device_vector d_C = h_C; + + double gflops = (2.0*m*n*k) * 1e-9; + + const int timing_iterations = 100; + GPU_Clock timer; + + int ldA = 0, ldB = 0, ldC = m; + + if (transA == 'N') { + ldA = m; + } else if (transA == 'T') { + ldA = k; + } else { + assert(false); + } + + if (transB == 'N') { + ldB = k; + } else if (transB == 'T') { + ldB = n; + } else { + assert(false); + } + + // Run once + d_C = h_C; + gemm(transA, transB, m, n, k, + alpha, + d_A.data().get(), ldA, + d_B.data().get(), ldB, + beta, + d_C.data().get(), ldC); + CUTE_CHECK_LAST(); + thrust::host_vector cute_result = d_C; + + // Timing iterations + timer.start(); + for (int i = 0; i < timing_iterations; ++i) { + gemm(transA, transB, m, n, k, + alpha, + d_A.data().get(), ldA, + d_B.data().get(), ldB, + beta, + d_C.data().get(), ldC); + } + double cute_time = timer.seconds() / timing_iterations; + CUTE_CHECK_LAST(); + printf("CUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000); + + return 0; +} diff --git a/examples/cute/tutorial/sgemm_sm80.cu b/examples/cute/tutorial/sgemm_sm80.cu new file mode 100644 index 
0000000000..5ae0bf0f8b --- /dev/null +++ b/examples/cute/tutorial/sgemm_sm80.cu @@ -0,0 +1,567 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#include +#include +#include + +#include +#include + +#include + +#include "cutlass/util/print_error.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/helper_cuda.hpp" + +template +__global__ static +__launch_bounds__(decltype(size(TiledMma{}))::value) +void +gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler, + TA const* A, AStride dA, ASmemLayout sA_layout, TiledCopyA copy_a, + TB const* B, BStride dB, BSmemLayout sB_layout, TiledCopyB copy_b, + TC * C, CStride dC, CSmemLayout , TiledMma mma, + Alpha alpha, Beta beta) +{ + using namespace cute; + + // Preconditions + CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K) + CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K) + + CUTE_STATIC_ASSERT_V(size(copy_a) == size(mma)); // NumThreads + CUTE_STATIC_ASSERT_V(size(copy_b) == size(mma)); // NumThreads + + static_assert(is_static::value); + static_assert(is_static::value); + static_assert(is_static::value); + + CUTE_STATIC_ASSERT_V(size<0>(ASmemLayout{}) == size<0>(cta_tiler)); // BLK_M + CUTE_STATIC_ASSERT_V(size<0>(CSmemLayout{}) == size<0>(cta_tiler)); // BLK_M + CUTE_STATIC_ASSERT_V(size<0>(BSmemLayout{}) == size<1>(cta_tiler)); // BLK_N + CUTE_STATIC_ASSERT_V(size<1>(CSmemLayout{}) == size<1>(cta_tiler)); // BLK_N + CUTE_STATIC_ASSERT_V(size<1>(ASmemLayout{}) == size<2>(cta_tiler)); // BLK_K + CUTE_STATIC_ASSERT_V(size<1>(BSmemLayout{}) == size<2>(cta_tiler)); // BLK_K + + CUTE_STATIC_ASSERT_V(congruent(select<0,2>(shape_MNK), dA)); // dA strides for shape MK + CUTE_STATIC_ASSERT_V(congruent(select<1,2>(shape_MNK), dB)); // dB strides for shape NK + CUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC)); // dC strides for shape MN + + // + // Full and Tiled Tensors + // + + // Represent the full tensors + Tensor mA = make_tensor(make_gmem_ptr(A), select<0,2>(shape_MNK), dA); // (M,K) + Tensor mB = make_tensor(make_gmem_ptr(B), select<1,2>(shape_MNK), dB); // (N,K) + Tensor mC = make_tensor(make_gmem_ptr(C), select<0,1>(shape_MNK), dC); // (M,N) + + // Get the appropriate blocks for this thread block + auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k) + Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) + Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N) + + // Shared memory buffers + __shared__ TA smemA[cosize_v]; + __shared__ TB smemB[cosize_v]; + Tensor sA = make_tensor(make_smem_ptr(smemA), sA_layout); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(smemB), sB_layout); // (BLK_N,BLK_K,PIPE) + + // + // Partition the copying of A and B tiles across the threads + // + + ThrCopy thr_copy_a = copy_a.get_slice(threadIdx.x); + Tensor tAgA = thr_copy_a.partition_S(gA); // (CPY,CPY_M,CPY_K,k) + Tensor tAsA = thr_copy_a.partition_D(sA); // (CPY,CPY_M,CPY_K,PIPE) + + ThrCopy thr_copy_b = copy_b.get_slice(threadIdx.x); + Tensor tBgB = thr_copy_b.partition_S(gB); // (CPY,CPY_N,CPY_K,k) + Tensor tBsB = thr_copy_b.partition_D(sB); // (CPY,CPY_N,CPY_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tAgA) == size<1>(tAsA)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tAgA) == size<2>(tAsA)); // CPY_K + CUTE_STATIC_ASSERT_V(size<1>(tBgB) == size<1>(tBsB)); // CPY_N + CUTE_STATIC_ASSERT_V(size<2>(tBgB) == size<2>(tBsB)); // CPY_K + + // + // PREFETCH + // + 
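+  // Note (added commentary, not part of the original tutorial text): tAsA carries a
+  // trailing PIPE mode, (CPY,CPY_M,CPY_K,PIPE), while tAgA carries the k-tile mode,
+  // (CPY,CPY_M,CPY_K,k). The loop below issues cp.async copies for every pipe stage
+  // except the last; the final stage is filled from inside the main loop, so
+  // global->shared traffic overlaps with the MMAs on earlier stages.
+  // A small sanity-check sketch in the same #if 0 style used elsewhere in this file
+  // (assumes it sits after the tAgA/tAsA partitions, as here):
+#if 0
+  if(thread0()) {
+    print("PIPE stages : "); print(size<3>(tAsA)); print("\n");
+    print("k-tiles     : "); print(size<3>(tAgA)); print("\n");
+  }
+#endif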
+ auto K_PIPE_MAX = size<3>(tAsA); + + // Total count of tiles + int k_tile_count = size<3>(tAgA); + // Current tile index in gmem to read from + int k_tile_next = 0; + + // Start async loads for all pipes but the last + CUTE_UNROLL + for (int k_pipe = 0; k_pipe < K_PIPE_MAX-1; ++k_pipe) { + copy(copy_a, tAgA(_,_,_,k_tile_next), tAsA(_,_,_,k_pipe)); + copy(copy_b, tBgB(_,_,_,k_tile_next), tBsB(_,_,_,k_pipe)); + cp_async_fence(); + --k_tile_count; + if (k_tile_count > 0) { ++k_tile_next; } + } + + // + // Define A/B partitioning and C accumulators + // + + ThrMMA thr_mma = mma.get_slice(threadIdx.x); + Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N) + + // Allocate registers for pipelining + Tensor tCrA = thr_mma.make_fragment_A(tCsA(_,_,_,0)); // (MMA,MMA_M,MMA_K) + Tensor tCrB = thr_mma.make_fragment_B(tCsB(_,_,_,0)); // (MMA,MMA_N,MMA_K) + // Allocate the accumulators -- same size as the projected data + Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N) + + CUTE_STATIC_ASSERT_V(( shape(tCrA) == take<0,3>(shape(tCsA)))); // (MMA,MMA_M,MMA_K) + CUTE_STATIC_ASSERT_V(( shape(tCrB) == take<0,3>(shape(tCsB)))); // (MMA,MMA_N,MMA_K) + CUTE_STATIC_ASSERT_V(( shape(tCrC) == take<0,3>(shape(tCgC)))); // (MMA,MMA_M,MMA_N) + CUTE_STATIC_ASSERT_V((size<1>(tCgC) == size<1>(tCsA))); // MMA_M + CUTE_STATIC_ASSERT_V((size<2>(tCgC) == size<1>(tCsB))); // MMA_N + CUTE_STATIC_ASSERT_V((size<2>(tCsA) == size<2>(tCsB))); // MMA_K + + // Clear the accumulators + clear(tCrC); + +#if 0 + if(thread0()) { + print(" mA : "); print( mA); print("\n"); + print(" gA : "); print( gA); print("\n"); + print(" sA : "); print( sA); print("\n"); + print("tAgA : "); print(tAgA); print("\n"); + print("tAsA : "); print(tAsA); print("\n"); + } +#endif + +#if 0 + if(thread0()) { + print(" mB : "); print( mB); print("\n"); + print(" gB : "); print( gB); print("\n"); + print(" sB : "); print( sB); print("\n"); + print("tBgB : "); print(tBgB); print("\n"); + print("tBsB : "); print(tBsB); print("\n"); + } +#endif + +#if 0 + if(thread0()) { + print(" mC : "); print( mC); print("\n"); + print(" gC : "); print( gC); print("\n"); + print("tCsA : "); print(tCsA); print("\n"); + print("tCsB : "); print(tCsB); print("\n"); + print("tCgC : "); print(tCgC); print("\n"); + print("tCrA : "); print(tCrA); print("\n"); + print("tCrB : "); print(tCrB); print("\n"); + print("tCrC : "); print(tCrC); print("\n"); + } +#endif + +#if 1 + + // Current pipe index in smem to read from + int smem_pipe_read = 0; + // Current pipe index in smem to write to + int smem_pipe_write = K_PIPE_MAX-1; + + // Pipe slice + Tensor tCsA_p = tCsA(_,_,_,smem_pipe_read); + Tensor tCsB_p = tCsB(_,_,_,smem_pipe_read); + + // Size of the register pipeline + auto K_BLOCK_MAX = size<2>(tCrA); + + // PREFETCH register pipeline + if (K_BLOCK_MAX > 1) { + // Wait until our first prefetched tile is loaded in + cp_async_wait(); + __syncthreads(); + + // Prefetch the first rmem from the first k-tile + copy(tCsA_p(_,_,Int<0>{}), tCrA(_,_,Int<0>{})); + copy(tCsB_p(_,_,Int<0>{}), tCrB(_,_,Int<0>{})); + } + + // + // PIPELINED MAIN LOOP + // TUTORIAL: Example of a gemm loop that pipelines shared memory using SM80's cp.async instructions + // and explicit pipelines in shared memory. + // Data is read from global(k_tile_next) to shared(smem_pipe_write). + // Data is read from shared(smem_pipe_read) to registers(k_block_next). 
+ // Data is computed on registers(b_block). + // + // This allows all copies and compute to overlap: + // Copy from gmem->smem can overlap with copies from smem->rmem and compute on rmem. + // Copy from smem->rmem can overlap with compute on rmem. + // + + CUTE_NO_UNROLL + while (k_tile_count > -(K_PIPE_MAX-1)) + { + CUTE_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) + { + if (k_block == K_BLOCK_MAX - 1) + { + // Slice the smem_pipe_read smem + tCsA_p = tCsA(_,_,_,smem_pipe_read); + tCsB_p = tCsB(_,_,_,smem_pipe_read); + + // Commit the smem for smem_pipe_read + cp_async_wait(); + __syncthreads(); + } + + // Load A, B shmem->regs for k_block+1 + auto k_block_next = (k_block + Int<1>{}) % K_BLOCK_MAX; // static + copy(tCsA_p(_,_,k_block_next), tCrA(_,_,k_block_next)); + copy(tCsB_p(_,_,k_block_next), tCrB(_,_,k_block_next)); + // Copy gmem to smem before computing gemm on each k-pipe + if (k_block == 0) + { + copy(copy_a, tAgA(_,_,_,k_tile_next), tAsA(_,_,_,smem_pipe_write)); + copy(copy_b, tBgB(_,_,_,k_tile_next), tBsB(_,_,_,smem_pipe_write)); + cp_async_fence(); + + // Advance the gmem tile + --k_tile_count; + if (k_tile_count > 0) { ++k_tile_next; } + + // Advance the smem pipe + smem_pipe_write = smem_pipe_read; + ++smem_pipe_read; + smem_pipe_read = (smem_pipe_read == K_PIPE_MAX) ? 0 : smem_pipe_read; + } + // Thread-level register gemm for k_block + gemm(mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + } + + } + +#endif + + // + // Epilogue + // + + axpby(alpha, tCrC, beta, tCgC); +} + +// Setup params for a NT GEMM +template +void +gemm_nt(int m, int n, int k, + Alpha alpha, + TA const* A, int ldA, + TB const* B, int ldB, + Beta beta, + TC * C, int ldC, + cudaStream_t stream = 0) +{ + using namespace cute; + + // Define shapes (dynamic) + auto M = int(m); + auto N = int(n); + auto K = int(k); + auto prob_shape = make_shape(M, N, K); // (M, N, K) + + // Define NT strides (mixed) + auto dA = make_stride(Int<1>{}, ldA); // (dM, dK) + auto dB = make_stride(Int<1>{}, ldB); // (dN, dK) + auto dC = make_stride(Int<1>{}, ldC); // (dM, dN) + + // Define CTA tile sizes (static) + auto bM = Int<128>{}; + auto bN = Int<128>{}; + auto bK = Int< 8>{}; + auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K) + auto bP = Int<3>{}; // Pipeline + + // Define the smem layouts (static) + auto sA = make_layout(make_shape(bM, bK, bP)); // (m,k,p) -> smem_idx; m-major + auto sB = make_layout(make_shape(bN, bK, bP)); // (n,k,p) -> smem_idx; n-major + auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx; m-major + + // Define the thread layouts (static) + + TiledCopy copyA = make_tiled_copy(Copy_Atom, TA>{}, + Layout>{}, // Thr layout 32x8 m-major + Layout>{});// Val layout 4x1 m-major + TiledCopy copyB = make_tiled_copy(Copy_Atom, TB>{}, + Layout>{}, // Thr layout 32x8 n-major + Layout>{});// Val layout 4x1 n-major + + TiledMMA mmaC = make_tiled_mma(UniversalFMA{}, + Layout>{}); // 16x16x1 TiledMMA + +#if 0 + print(copyA); + print(copyB); + print(mmaC); +#endif + +#if 0 + print_latex(copyA); + print_latex(copyB); + print_latex(mmaC); +#endif + + dim3 dimBlock(size(mmaC)); + dim3 dimGrid(size(ceil_div(M, bM)), + size(ceil_div(N, bN))); + gemm_device<<>> + (prob_shape, cta_tiler, + A, dA, sA, copyA, + B, dB, sB, copyB, + C, dC, sC, mmaC, + alpha, beta); +} + +// Setup params for a TN GEMM +template +void +gemm_tn(int m, int n, int k, + Alpha alpha, + TA const* A, int ldA, + TB const* B, int ldB, + Beta beta, + TC * C, int ldC, + cudaStream_t stream = 0) +{ + using 
namespace cute; + + // Define shapes (dynamic) + auto M = int(m); + auto N = int(n); + auto K = int(k); + auto prob_shape = make_shape(M, N, K); // (M, N, K) + + // Define TN strides (mixed) + auto dA = make_stride(ldA, Int<1>{}); // (dM, dK) + auto dB = make_stride(ldB, Int<1>{}); // (dN, dK) + auto dC = make_stride(Int<1>{}, ldC); // (dM, dN) + + // Define CTA tile sizes (static) + auto bM = Int<128>{}; + auto bN = Int<128>{}; + auto bK = Int< 8>{}; + auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K) + auto bP = Int<3>{}; // Pipeline + + // Define the smem layouts (static) + auto sA_atom = make_layout(make_shape ( bM, bK), + make_stride(Int<1>{}, bM+Int<1>{})); // (m,k) -> smem_idx; padded m-major + [[maybe_unused]] auto sB_atom = make_layout(make_shape ( bN, bK), + make_stride(Int<1>{}, bN+Int<1>{})); // (n,k) -> smem_idx; padded n-major + auto sA = tile_to_shape(sA_atom, make_shape(bM, bK, bP)); + auto sB = tile_to_shape(sA_atom, make_shape(bN, bK, bP)); + auto sC = make_layout(make_shape(bM, bN)); // (m,n) -> smem_idx + + // Define the thread layouts (static) + + TiledCopy copyA = make_tiled_copy(Copy_Atom, TA>{}, + Layout,Stride<_8,_1>>{}, // Thr layout 32x8 k-major + Layout>{}); // Val layout 1x1 + TiledCopy copyB = make_tiled_copy(Copy_Atom, TB>{}, + Layout,Stride<_8,_1>>{}, // Thr layout 32x8 k-major + Layout>{}); // Val layout 1x1 + + TiledMMA mmaC = make_tiled_mma(UniversalFMA{}, + Layout>{}); // 16x16x1 TiledMMA + +#if 0 + print(copyA); + print(copyB); + print(mmaC); +#endif + +#if 0 + print_latex(copyA); + print_latex(copyB); + print_latex(mmaC); +#endif + + dim3 dimBlock(size(mmaC)); + dim3 dimGrid(size(ceil_div(M, bM)), + size(ceil_div(N, bN))); + gemm_device<<>> + (prob_shape, cta_tiler, + A, dA, sA, copyA, + B, dB, sB, copyB, + C, dC, sC, mmaC, + alpha, beta); +} + +template +void +gemm(char transA, char transB, int m, int n, int k, + Alpha alpha, + TA const* A, int ldA, + TB const* B, int ldB, + Beta beta, + TC * C, int ldC, + cudaStream_t stream = 0) +{ + if (transA == 'N' && transB == 'T') { + return gemm_nt(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream); + } else + if (transA == 'T' && transB == 'N') { + return gemm_tn(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream); + } + assert(false && "Not implemented"); +} + + +int main(int argc, char** argv) +{ + cudaDeviceProp props; + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (props.major < 8) { + std::cout << "This example requires an Ampere GPU or newer (CC >= 80)" << std::endl; + // Return 0 so tests pass if run on unsupported architectures or CUDA Toolkits. 
+ return 0; + } + + int m = 5120; + if (argc >= 2) + sscanf(argv[1], "%d", &m); + + int n = 5120; + if (argc >= 3) + sscanf(argv[2], "%d", &n); + + int k = 4096; + if (argc >= 4) + sscanf(argv[3], "%d", &k); + + char transA = 'N'; + if (argc >= 5) + sscanf(argv[4], "%c", &transA); + + char transB = 'T'; + if (argc >= 6) + sscanf(argv[5], "%c", &transB); + + using TA = float; + using TB = float; + using TC = float; + using TI = float; + + TI alpha = 1.0; + TI beta = 0.0; + + std::cout << "M = " << m << std::endl; + std::cout << "N = " << n << std::endl; + std::cout << "K = " << k << std::endl; + std::cout << "C = A^" << transA << " B^" << transB << std::endl; + + thrust::host_vector h_A(m*k); + thrust::host_vector h_B(n*k); + thrust::host_vector h_C(m*n); + + for (int j = 0; j < m*k; ++j) h_A[j] = static_cast( 2*(rand() / double(RAND_MAX)) - 1 ); + for (int j = 0; j < n*k; ++j) h_B[j] = static_cast( 2*(rand() / double(RAND_MAX)) - 1 ); + for (int j = 0; j < m*n; ++j) h_C[j] = static_cast(-1); + + thrust::device_vector d_A = h_A; + thrust::device_vector d_B = h_B; + thrust::device_vector d_C = h_C; + + double gflops = (2.0*m*n*k) * 1e-9; + + const int timing_iterations = 100; + GPU_Clock timer; + + int ldA = 0, ldB = 0, ldC = m; + + if (transA == 'N') { + ldA = m; + } else if (transA == 'T') { + ldA = k; + } else { + assert(false); + } + + if (transB == 'N') { + ldB = k; + } else if (transB == 'T') { + ldB = n; + } else { + assert(false); + } + + // Run once + d_C = h_C; + gemm(transA, transB, m, n, k, + alpha, + d_A.data().get(), ldA, + d_B.data().get(), ldB, + beta, + d_C.data().get(), ldC); + CUTE_CHECK_LAST(); + thrust::host_vector cute_result = d_C; + + // Timing iterations + timer.start(); + for (int i = 0; i < timing_iterations; ++i) { + gemm(transA, transB, m, n, k, + alpha, + d_A.data().get(), ldA, + d_B.data().get(), ldB, + beta, + d_C.data().get(), ldC); + } + double cute_time = timer.seconds() / timing_iterations; + CUTE_CHECK_LAST(); + printf("CUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000); + + return 0; +} diff --git a/examples/cute/tutorial/tiled_copy.cu b/examples/cute/tutorial/tiled_copy.cu new file mode 100644 index 0000000000..a8ae3b1040 --- /dev/null +++ b/examples/cute/tutorial/tiled_copy.cu @@ -0,0 +1,256 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#include +#include + +#include + +#include "cutlass/util/print_error.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/helper_cuda.hpp" + +// This is a simple tutorial showing several ways to partition a tensor into tiles then +// perform efficient, coalesced copies. This example also shows how to vectorize accesses +// which may be a useful optimization or required for certain workloads. +// +// `copy_kernel()` and `copy_kernel_vectorized()` each assume a pair of tensors with +// dimensions (m, n) have been partitioned via `tiled_divide()`. +// +// The result are a part of compatible tensors with dimensions ((M, N), m', n'), where +// (M, N) denotes a statically sized tile, and m' and n' denote the number of such tiles +// within the tensor. +// +// Each statically sized tile is mapped to a CUDA threadblock which performs efficient +// loads and stores to Global Memory. +// +// `copy_kernel()` uses `cute::local_partition()` to partition the tensor and map +// the result to threads using a striped indexing scheme. Threads themselve are arranged +// in a (ThreadShape_M, ThreadShape_N) arrangement which is replicated over the tile. +// +// `copy_kernel_vectorized()` uses `cute::make_tiled_copy()` to perform a similar +// partitioning using `cute::Copy_Atom` to perform vectorization. The actual vector +// size is defined by `ThreadShape`. +// +// This example assumes the overall tensor shape is divisible by the tile size and +// does not perform predication. + + +/// Simple copy kernel. +// +// Uses local_partition() to partition a tile among threads arranged as (THR_M, THR_N). +template +__global__ void copy_kernel(TensorS S, TensorD D, ThreadLayout) +{ + using namespace cute; + + // Slice the tiled tensors + Tensor tile_S = S(make_coord(_,_), blockIdx.x, blockIdx.y); // (BlockShape_M, BlockShape_N) + Tensor tile_D = D(make_coord(_,_), blockIdx.x, blockIdx.y); // (BlockShape_M, BlockShape_N) + + // Construct a partitioning of the tile among threads with the given thread arrangement. + + // Concept: Tensor ThrLayout ThrIndex + Tensor thr_tile_S = local_partition(tile_S, ThreadLayout{}, threadIdx.x); // (ThrValM, ThrValN) + Tensor thr_tile_D = local_partition(tile_D, ThreadLayout{}, threadIdx.x); // (ThrValM, ThrValN) + + // Construct a register-backed Tensor with the same shape as each thread's partition + // Use make_tensor to try to match the layout of thr_tile_S + Tensor fragment = make_tensor_like(thr_tile_S); // (ThrValM, ThrValN) + + // Copy from GMEM to RMEM and from RMEM to GMEM + copy(thr_tile_S, fragment); + copy(fragment, thr_tile_D); +} + +/// Vectorized copy kernel. +/// +/// Uses `make_tiled_copy()` to perform a copy using vector instructions. This operation +/// has the precondition that pointers are aligned to the vector size. 
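+///
+/// For example, with Element = float and the (4,1) value layout used in main()
+/// below, each thread moves four contiguous floats (16 bytes) per access, so the
+/// source and destination pointers should be 16-byte aligned (allocations from
+/// cudaMalloc / thrust::device_vector satisfy this).
+///
+/// A minimal host-side sketch of building and launching such a copy (illustrative
+/// only; the 32x8 thread and 4x1 value layouts match the ones defined in main()
+/// below, while the 128-bit copy atom shown here is just one valid choice):
+///
+///   TiledCopy tiled_copy =
+///     make_tiled_copy(Copy_Atom<UniversalCopy<uint128_t>, float>{},   // 128-bit loads/stores
+///                     make_layout(make_shape(Int<32>{}, Int<8>{})),   // 32x8 threads
+///                     make_layout(make_shape(Int< 4>{}, Int<1>{})));  // 4x1 values per thread
+///   copy_kernel_vectorized<<<gridDim, 32*8>>>(tiled_S, tiled_D, tiled_copy);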
+/// +template +__global__ void copy_kernel_vectorized(TensorS S, TensorD D, Tiled_Copy tiled_copy) +{ + using namespace cute; + + // Slice the tensors to obtain a view into each tile. + Tensor tile_S = S(make_coord(_, _), blockIdx.x, blockIdx.y); // (BlockShape_M, BlockShape_N) + Tensor tile_D = D(make_coord(_, _), blockIdx.x, blockIdx.y); // (BlockShape_M, BlockShape_N) + + // Construct a Tensor corresponding to each thread's slice. + ThrCopy thr_copy = tiled_copy.get_thread_slice(threadIdx.x); + + Tensor thr_tile_S = thr_copy.partition_S(tile_S); // (CopyOp, CopyM, CopyN) + Tensor thr_tile_D = thr_copy.partition_D(tile_D); // (CopyOp, CopyM, CopyN) + + // Construct a register-backed Tensor with the same shape as each thread's partition + // Use make_fragment because the first mode is the instruction-local mode + Tensor fragment = make_fragment_like(thr_tile_D); // (CopyOp, CopyM, CopyN) + + // Copy from GMEM to RMEM and from RMEM to GMEM + copy(tiled_copy, thr_tile_S, fragment); + copy(tiled_copy, fragment, thr_tile_D); +} + +/// Main function +int main(int argc, char** argv) +{ + // + // Given a 2D shape, perform an efficient copy + // + + using namespace cute; + using Element = float; + + // Define a tensor shape with dynamic extents (m, n) + auto tensor_shape = make_shape(256, 512); + + // + // Allocate and initialize + // + + thrust::host_vector h_S(size(tensor_shape)); + thrust::host_vector h_D(size(tensor_shape)); + + for (size_t i = 0; i < h_S.size(); ++i) { + h_S[i] = static_cast(i); + h_D[i] = Element{}; + } + + thrust::device_vector d_S = h_S; + thrust::device_vector d_D = h_D; + + // + // Make tensors + // + + Tensor tensor_S = make_tensor(make_gmem_ptr(thrust::raw_pointer_cast(d_S.data())), make_layout(tensor_shape)); + Tensor tensor_D = make_tensor(make_gmem_ptr(thrust::raw_pointer_cast(d_D.data())), make_layout(tensor_shape)); + + // + // Tile tensors + // + + // Define a statically sized block (M, N). + // Note, by convention, capital letters are used to represent static modes. + auto block_shape = make_shape(Int<128>{}, Int<64>{}); + + if ((size<0>(tensor_shape) % size<0>(block_shape)) || (size<1>(tensor_shape) % size<1>(block_shape))) { + std::cerr << "The tensor shape must be divisible by the block shape." << std::endl; + return -1; + } + // Equivalent check to the above + if (not evenly_divides(tensor_shape, block_shape)) { + std::cerr << "Expected the block_shape to evenly divide the tensor shape." << std::endl; + return -1; + } + + // Tile the tensor (m, n) ==> ((M, N), m', n') where (M, N) is the static tile + // shape, and modes (m', n') correspond to the number of tiles. + // + // These will be used to determine the CUDA kernel grid dimensions. + Tensor tiled_tensor_S = tiled_divide(tensor_S, block_shape); // ((M, N), m', n') + Tensor tiled_tensor_D = tiled_divide(tensor_D, block_shape); // ((M, N), m', n') + + // Construct a TiledCopy with a specific access pattern. + // This version uses a + // (1) Layout-of-Threads to describe the number and arrangement of threads (e.g. row-major, col-major, etc), + // (2) Layout-of-Values that each thread will access. + + // Thread arrangement + Layout thr_layout = make_layout(make_shape(Int<32>{}, Int<8>{})); // (32,8) -> thr_idx + + // Value arrangement per thread + Layout val_layout = make_layout(make_shape(Int<4>{}, Int<1>{})); // (4,1) -> val_idx + + // Define `AccessType` which controls the size of the actual memory access instruction. 
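+  // (Added note) With the (4,1) value layout above and Element = float, each
+  // thread owns 4 x 4 B = 16 B of contiguous data, so a 128-bit access type lets
+  // that whole per-thread tile move as one vectorized load/store. Narrower access
+  // types remain correct but issue more, smaller transactions; wider ones are not
+  // compatible with the 4x1 value layout and are rejected at compile time.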
+ using CopyOp = UniversalCopy>; // A very specific access width copy instruction + //using CopyOp = UniversalCopy>; // A more generic type that supports many copy strategies + //using CopyOp = AutoVectorizingCopy; // An adaptable-width instruction that assumes maximal alignment of inputs + + // A Copy_Atom corresponds to one CopyOperation applied to Tensors of type Element. + using Atom = Copy_Atom; + + // Construct tiled copy, a tiling of copy atoms. + // + // Note, this assumes the vector and thread layouts are aligned with contigous data + // in GMEM. Alternative thread layouts are possible but may result in uncoalesced + // reads. Alternative value layouts are also possible, though incompatible layouts + // will result in compile time errors. + TiledCopy tiled_copy = make_tiled_copy(Atom{}, // Access strategy + thr_layout, // thread layout (e.g. 32x4 Col-Major) + val_layout); // value layout (e.g. 4x1) + + // + // Determine grid and block dimensions + // + + dim3 gridDim (size<1>(tiled_tensor_D), size<2>(tiled_tensor_D)); // Grid shape corresponds to modes m' and n' + dim3 blockDim(size(thr_layout)); + + // + // Launch the kernel + // + copy_kernel_vectorized<<< gridDim, blockDim >>>( + tiled_tensor_S, + tiled_tensor_D, + tiled_copy); + + cudaError result = cudaDeviceSynchronize(); + if (result != cudaSuccess) { + std::cerr << "CUDA Runtime error: " << cudaGetErrorString(result) << std::endl; + return -1; + } + + // + // Verify + // + + h_D = d_D; + + int32_t errors = 0; + int32_t const kErrorLimit = 10; + + for (size_t i = 0; i < h_D.size(); ++i) { + if (h_S[i] != h_D[i]) { + std::cerr << "Error. S[" << i << "]: " << h_S[i] << ", D[" << i << "]: " << h_D[i] << std::endl; + + if (++errors >= kErrorLimit) { + std::cerr << "Aborting on " << kErrorLimit << "nth error." << std::endl; + return -1; + } + } + } + + std::cout << "Success." << std::endl; + + return 0; +} + diff --git a/examples/cute/tutorial/wgmma_sm90.cu b/examples/cute/tutorial/wgmma_sm90.cu new file mode 100644 index 0000000000..0baa494a37 --- /dev/null +++ b/examples/cute/tutorial/wgmma_sm90.cu @@ -0,0 +1,562 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#include +#include +#include + +#include +#include + +#include + +#include "cutlass/cluster_launch.hpp" +#include "cutlass/arch/barrier.h" +#include "cutlass/pipeline/sm90_pipeline.hpp" + +#include "cutlass/util/print_error.hpp" +#include "cutlass/util/GPU_Clock.hpp" +#include "cutlass/util/helper_cuda.hpp" +#include "cutlass/arch/mma_sm90.h" +#include "cutlass/device_kernel.h" + +using namespace cute; + +template // (N,K,P) +struct SharedStorage +{ + array_aligned> smem_A; + array_aligned> smem_B; + + uint64_t tma_barrier[size<2>(SmemLayoutA{})]; + uint64_t mma_barrier[size<2>(SmemLayoutA{})]; +}; + +template +__global__ static +__launch_bounds__(decltype(size(TiledMma{}))::value) +void +gemm_device(ProblemShape shape_MNK, CtaTiler cta_tiler, + TA const* A, CUTLASS_GRID_CONSTANT TmaA const tma_a, + TB const* B, CUTLASS_GRID_CONSTANT TmaB const tma_b, + TC * C, CStride dC, TiledMma mma, + Alpha alpha, Beta beta) +{ + // Preconditions + CUTE_STATIC_ASSERT_V(rank(shape_MNK) == Int<3>{}); // (M, N, K) + CUTE_STATIC_ASSERT_V(rank(cta_tiler) == Int<3>{}); // (BLK_M, BLK_N, BLK_K) + + static_assert(is_static::value); + static_assert(is_static::value); + + CUTE_STATIC_ASSERT_V(size<0>(SmemLayoutA{}) == size<0>(cta_tiler)); // BLK_M + CUTE_STATIC_ASSERT_V(size<0>(SmemLayoutB{}) == size<1>(cta_tiler)); // BLK_N + CUTE_STATIC_ASSERT_V(size<1>(SmemLayoutA{}) == size<2>(cta_tiler)); // BLK_K + CUTE_STATIC_ASSERT_V(size<1>(SmemLayoutB{}) == size<2>(cta_tiler)); // BLK_K + + CUTE_STATIC_ASSERT_V(congruent(select<0,1>(shape_MNK), dC)); // dC strides for shape MN + + // + // Full and Tiled Tensors + // + + // Represent the full tensors + auto [M, N, K] = shape_MNK; + Tensor mA = tma_a.get_tma_tensor(make_shape(M,K)); // (M,K) TMA Tensor + Tensor mB = tma_b.get_tma_tensor(make_shape(N,K)); // (N,K) TMA Tensor + Tensor mC = make_tensor(make_gmem_ptr(C), make_shape(M,N), dC); // (M,N) + + // Get the appropriate blocks for this thread block + auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); // (m,n,k) + Tensor gA = local_tile(mA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (BLK_M,BLK_K,k) + Tensor gB = local_tile(mB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (BLK_N,BLK_K,k) + Tensor gC = local_tile(mC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (BLK_M,BLK_N) + + // Shared memory tensors + extern __shared__ char shared_memory[]; + using SharedStorage = SharedStorage; + SharedStorage& smem = *reinterpret_cast(shared_memory); + Tensor sA = make_tensor(make_smem_ptr(smem.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(smem.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Partition the copying of A and B tiles + // + // TUTORIAL: + // These are TMA partitionings, which have a dedicated custom partitioner. + // The Int<0>, Layout<_1> indicates that the TMAs are not multicasted. 
+ // Any multicasting must be in conformance with tma_x constructed with make_tma_atom on host. + // The group_modes<0,2> transforms the (X,Y,Z)-shaped tensors into ((X,Y),Z)-shaped tensors + // with the understanding that the TMA is responsible for everything in mode-0. + // The tma_partition reorders and offsets mode-0 according to the tma_x atom and the multicast info. + // + + auto [tAgA, tAsA] = tma_partition(tma_a, Int<0>{}, Layout<_1>{}, + group_modes<0,2>(sA), group_modes<0,2>(gA)); // (TMA,k) and (TMA,PIPE) + + auto [tBgB, tBsB] = tma_partition(tma_b, Int<0>{}, Layout<_1>{}, + group_modes<0,2>(sB), group_modes<0,2>(gB)); // (TMA,k) and (TMA,PIPE) + + // The TMA is responsible for copying everything in mode-0 of tAsA and tBsB + constexpr int kTmaTransactionBytes = CUTE_STATIC_V(size<0>(tAsA)) * sizeof(TA) + + CUTE_STATIC_V(size<0>(tBsB)) * sizeof(TB); + + // + // PREFETCH + // + + auto K_PIPE_MAX = size<1>(tAsA); + + // Total count of tiles + int k_tile_count = size<1>(tAgA); + // Current tile index in gmem to read from + int k_tile = 0; + + // Initialize Barriers + int warp_idx = cutlass::canonical_warp_idx_sync(); + int lane_predicate = cute::elect_one_sync(); + uint64_t* producer_mbar = smem.tma_barrier; + uint64_t* consumer_mbar = smem.mma_barrier; + + using ProducerBarType = cutlass::arch::ClusterTransactionBarrier; // TMA + using ConsumerBarType = cutlass::arch::ClusterBarrier; // MMA + CUTE_UNROLL + for (int pipe = 0; pipe < K_PIPE_MAX; ++pipe) { + if ((warp_idx == 0) && lane_predicate) { + ProducerBarType::init(&producer_mbar[pipe], 1); + ConsumerBarType::init(&consumer_mbar[pipe], 128); + } + } + // Ensure barrier init is complete on all CTAs + cluster_sync(); + + // Start async loads for all pipes + CUTE_UNROLL + for (int pipe = 0; pipe < K_PIPE_MAX; ++pipe) + { + if ((warp_idx == 0) && lane_predicate) + { + // Set expected Tx Bytes after each reset / init + ProducerBarType::arrive_and_expect_tx(&producer_mbar[pipe], kTmaTransactionBytes); + copy(tma_a.with(producer_mbar[pipe]), tAgA(_,k_tile), tAsA(_,pipe)); + copy(tma_b.with(producer_mbar[pipe]), tBgB(_,k_tile), tBsB(_,pipe)); + } + --k_tile_count; + ++k_tile; + } + + // + // Define A/B partitioning and C accumulators + // + // TUTORIAL: + // The tCrA and tCrB are actually Tensors of MMA Descriptors constructed as views of SMEM. + // The MMA Descriptor generation is automatic via inspection and validation of the SMEM Layouts. + // Because the MMA reads directly from SMEM and the fragments are descriptors rather than registers, + // there is no need for copy(tCsA, tCrA) in the mainloop. + // + + ThrMMA thr_mma = mma.get_thread_slice(threadIdx.x); + Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + Tensor tCgC = thr_mma.partition_C(gC); // (MMA,MMA_M,MMA_N) + + // Allocate accumulators and clear them + Tensor tCrC = thr_mma.make_fragment_C(tCgC); // (MMA,MMA_M,MMA_N) + clear(tCrC); + + // Allocate "fragments" + Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + // + // PIPELINED MAIN LOOP + // + // TUTORIAL: + // Rather than interleaving the stages and instructions like in SM70 and SM80, + // the SM90 mainloops rely on explicit producer-consumer synchronization + // on the purely async instructions TMA and MMA. + // More advanced pipeline and warp-specialization strategies are available in CUTLASS mainloops. 
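+  //
+  // Added summary (not in the original comment) of the per-stage handshake below:
+  //   producer (one elected thread): ConsumerBar::wait(stage, phase) ->
+  //     arrive_and_expect_tx -> TMA copies A/B into the stage; the TMA transaction
+  //     itself completes the producer barrier.
+  //   consumers (all 128 threads):   ProducerBar::wait(stage, phase) -> WGMMAs on
+  //     that stage -> ConsumerBar::arrive, releasing the stage for the next TMA write.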
+ // + + // A PipelineState is a circular pipe index [.index()] and a pipe phase [.phase()] + // that flips each cycle through K_PIPE_MAX. + auto write_state = cutlass::PipelineState(); // TMA writes + auto read_state = cutlass::PipelineState(); // MMA reads + + CUTE_NO_UNROLL + while (k_tile_count > -K_PIPE_MAX) + { + // Wait for Producer to complete + int read_pipe = read_state.index(); + ProducerBarType::wait(&producer_mbar[read_pipe], read_state.phase()); + + // MMAs to cover 1 K_TILE + warpgroup_arrive(); + gemm(mma, tCrA(_,_,_,read_pipe), tCrB(_,_,_,read_pipe), tCrC); // (V,M) x (V,N) => (V,M,N) + warpgroup_commit_batch(); + + // Wait for all MMAs in a K_TILE to complete + warpgroup_wait<0>(); + + // Notify that consumption is done + ConsumerBarType::arrive(&consumer_mbar[read_pipe]); + ++read_state; + + if ((warp_idx == 0) && lane_predicate) + { + int pipe = write_state.index(); + // Wait for Consumer to complete consumption + ConsumerBarType::wait(&consumer_mbar[pipe], write_state.phase()); + // Set expected Tx Bytes after each reset / init + ProducerBarType::arrive_and_expect_tx(&producer_mbar[pipe], kTmaTransactionBytes); + copy(tma_a.with(producer_mbar[pipe]), tAgA(_,k_tile), tAsA(_,pipe)); + copy(tma_b.with(producer_mbar[pipe]), tBgB(_,k_tile), tBsB(_,pipe)); + ++write_state; + } + --k_tile_count; + ++k_tile; + } + + // + // Epilogue (unpredicated) + // + + axpby(alpha, tCrC, beta, tCgC); +} + +// Setup params for an NT GEMM +template +void +gemm_nt(int m, int n, int k, + Alpha alpha, + TA const* A, int ldA, + TB const* B, int ldB, + Beta beta, + TC * C, int ldC, + cudaStream_t stream = 0) +{ + // Define shapes (dynamic) + auto M = int(m); + auto N = int(n); + auto K = int(k); + auto prob_shape = make_shape(M, N, K); // (M, N, K) + + // Define TN strides (mixed) + auto dA = make_stride(Int<1>{}, ldA); // (dM, dK) + auto dB = make_stride(Int<1>{}, ldB); // (dN, dK) + auto dC = make_stride(Int<1>{}, ldC); // (dM, dN) + + // Define CTA tile sizes (static) + auto bM = Int<128>{}; + auto bN = Int<128>{}; + auto bK = Int< 64>{}; + auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K) + auto bP = Int< 3>{}; // Pipeline + + // Define the smem layouts (static) + auto sA = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, make_shape(bM,bK,bP)); + auto sB = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, make_shape(bN,bK,bP)); + + // Define the MMA + TiledMMA tiled_mma = make_tiled_mma(SM90_64x64x16_F16F16F16_SS{}); + + // Define the TMAs + // Create Global memory tensors for TMA inspection + Tensor mA = make_tensor(A, make_shape(M,K), dA); + Tensor mB = make_tensor(B, make_shape(N,K), dB); + + // Create TMA Atoms with the desired copy operation on the source and destination + Copy_Atom tmaA = make_tma_atom(SM90_TMA_LOAD{}, mA, sA(_,_,0), make_shape(bM,bK)); + Copy_Atom tmaB = make_tma_atom(SM90_TMA_LOAD{}, mB, sB(_,_,0), make_shape(bN,bK)); + + // + // Setup and Launch + // + + // Launch parameter setup + int smem_size = int(sizeof(SharedStorage)); + dim3 dimBlock(size(tiled_mma)); + dim3 dimCluster(2, 1, 1); + dim3 dimGrid(round_up(size(ceil_div(m, bM)), dimCluster.x), + round_up(size(ceil_div(n, bN)), dimCluster.y)); + cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size}; + + void const* kernel_ptr = reinterpret_cast( + &gemm_device); + + CUTE_CHECK_ERROR(cudaFuncSetAttribute( + kernel_ptr, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + + // Kernel Launch + cutlass::Status status = cutlass::launch_kernel_on_cluster(params, kernel_ptr, + 
prob_shape, cta_tiler, + A, tmaA, + B, tmaB, + C, dC, tiled_mma, + alpha, beta); + CUTE_CHECK_LAST(); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "Error: Failed at kernel Launch" << std::endl; + } +} + +// Setup params for a TN GEMM +template +void +gemm_tn(int m, int n, int k, + Alpha alpha, + TA const* A, int ldA, + TB const* B, int ldB, + Beta beta, + TC * C, int ldC, + cudaStream_t stream = 0) +{ + // Define shapes (dynamic) + auto M = int(m); + auto N = int(n); + auto K = int(k); + auto prob_shape = make_shape(M, N, K); // (M, N, K) + + // Define TN strides (mixed) + auto dA = make_stride(ldA, Int<1>{}); // (dM, dK) + auto dB = make_stride(ldB, Int<1>{}); // (dN, dK) + auto dC = make_stride(Int<1>{}, ldC); // (dM, dN) + + // Define CTA tile sizes (static) + auto bM = Int<128>{}; + auto bN = Int<128>{}; + auto bK = Int< 64>{}; + auto cta_tiler = make_shape(bM, bN, bK); // (BLK_M, BLK_N, BLK_K) + auto bP = Int<3>{}; // Pipeline + + // Define the smem layouts (static) + auto sA = tile_to_shape(GMMA::Layout_K_SW128_Atom{}, make_shape(bM,bK,bP)); + auto sB = tile_to_shape(GMMA::Layout_K_SW128_Atom{}, make_shape(bN,bK,bP)); + + // Define the MMA + TiledMMA tiled_mma = make_tiled_mma(SM90_64x64x16_F16F16F16_SS{}); + + // Define the TMAs + // Create Global memory tensors for TMA inspection + Tensor mA = make_tensor(A, make_shape(M,K), dA); + Tensor mB = make_tensor(B, make_shape(N,K), dB); + + // Create TMA Atoms with the desired copy operation on the source and destination + Copy_Atom tmaA = make_tma_atom(SM90_TMA_LOAD{}, mA, sA(_,_,0), make_shape(bM,bK)); + Copy_Atom tmaB = make_tma_atom(SM90_TMA_LOAD{}, mB, sB(_,_,0), make_shape(bN,bK)); + + // + // Setup and Launch + // + + // Launch parameter setup + int smem_size = int(sizeof(SharedStorage)); + dim3 dimBlock(size(tiled_mma)); + dim3 dimCluster(2, 1, 1); + dim3 dimGrid(round_up(size(ceil_div(m, bM)), dimCluster.x), + round_up(size(ceil_div(n, bN)), dimCluster.y)); + cutlass::ClusterLaunchParams params = {dimGrid, dimBlock, dimCluster, smem_size}; + + void const* kernel_ptr = reinterpret_cast( + &gemm_device); + + CUTE_CHECK_ERROR(cudaFuncSetAttribute( + kernel_ptr, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size)); + + // Kernel Launch + cutlass::Status status = cutlass::launch_kernel_on_cluster(params, kernel_ptr, + prob_shape, cta_tiler, + A, tmaA, + B, tmaB, + C, dC, tiled_mma, + alpha, beta); + CUTE_CHECK_LAST(); + + if (status != cutlass::Status::kSuccess) { + std::cerr << "Error: Failed at kernel Launch" << std::endl; + } +} + +template +void +gemm(char transA, char transB, int m, int n, int k, + Alpha alpha, + TA const* A, int ldA, + TB const* B, int ldB, + Beta beta, + TC * C, int ldC, + cudaStream_t stream = 0) +{ + if (transA == 'N' && transB == 'T') { + return gemm_nt(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream); + } else + if (transA == 'T' && transB == 'N') { + return gemm_tn(m, n, k, alpha, A, ldA, B, ldB, beta, C, ldC, stream); + } + assert(false && "Not implemented"); +} + +int main(int argc, char** argv) +{ + + cudaDeviceProp props; + cudaError_t error = cudaGetDeviceProperties(&props, 0); + if (error != cudaSuccess) { + std::cerr << "cudaGetDeviceProperties() returned an error: " << cudaGetErrorString(error) << std::endl; + return -1; + } + + if (props.major != 9) { + std::cout << "This example requires NVIDIA's Hopper Architecture GPU with compute capability 90a\n" << std::endl; + return 0; + } + +#if defined(CUTLASS_ARCH_MMA_SM90_SUPPORTED) + + int m = 512; + if (argc >= 2) + 
sscanf(argv[1], "%d", &m); + + int n = 256; + if (argc >= 3) + sscanf(argv[2], "%d", &n); + + int k = 1024; + if (argc >= 4) + sscanf(argv[3], "%d", &k); + + char transA = 'N'; + if (argc >= 5) + sscanf(argv[4], "%c", &transA); + + char transB = 'T'; + if (argc >= 6) + sscanf(argv[5], "%c", &transB); + + using TA = cute::half_t; + using TB = cute::half_t; + using TC = cute::half_t; + using TI = cute::half_t; + + TI alpha = TI(1.0f); + TI beta = TI(0.0f); + + thrust::host_vector h_A(m*k); + thrust::host_vector h_B(n*k); + thrust::host_vector h_C(m*n); + + // Initialize the tensors + for (int j = 0; j < m*k; ++j) h_A[j] = TA(int((rand() % 2) ? 1 : -1)); + for (int j = 0; j < n*k; ++j) h_B[j] = TB(int((rand() % 2) ? 1 : -1)); + for (int j = 0; j < m*n; ++j) h_C[j] = TC(0); + + thrust::device_vector d_A = h_A; + thrust::device_vector d_B = h_B; + thrust::device_vector d_C = h_C; + + double gflops = (2.0*m*n*k) * 1e-9; + + const int timing_iterations = 100; + GPU_Clock timer; + + int ldA = 0, ldB = 0, ldC = m; + + if (transA == 'N') { + ldA = m; + } else if (transA == 'T') { + ldA = k; + } else { + assert(false); + } + + if (transB == 'N') { + ldB = k; + } else if (transB == 'T') { + ldB = n; + } else { + assert(false); + } + + // Run once + d_C = h_C; + gemm(transA, transB, m, n, k, + alpha, + d_A.data().get(), ldA, + d_B.data().get(), ldB, + beta, + d_C.data().get(), ldC); + CUTE_CHECK_LAST(); + thrust::host_vector cute_result = d_C; + + // Timing iterations + timer.start(); + for (int i = 0; i < timing_iterations; ++i) { + gemm(transA, transB, m, n, k, + alpha, + d_A.data().get(), ldA, + d_B.data().get(), ldB, + beta, + d_C.data().get(), ldC); + } + double cute_time = timer.seconds() / timing_iterations; + CUTE_CHECK_LAST(); + printf("CUTE_GEMM: [%6.1f]GFlop/s (%6.4f)ms\n", gflops / cute_time, cute_time*1000); + +#else + + std::cout << "CUTLASS_ARCH_MMA_SM90_SUPPORTED must be enabled, but it is not. Test is waived \n" << std::endl; +#endif + + return 0; + +} diff --git a/examples/python/00_basic_gemm.ipynb b/examples/python/00_basic_gemm.ipynb new file mode 100644 index 0000000000..c27955517e --- /dev/null +++ b/examples/python/00_basic_gemm.ipynb @@ -0,0 +1,475 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "1ef96b3f", + "metadata": {}, + "source": [ + "# Basic example of using the CUTLASS Python interface\n", + "This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run GEMMs.\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/blob/main/examples/python/00_basic_gemm.ipynb)\n" + ] + }, + { + "cell_type": "markdown", + "id": "df94d7e6", + "metadata": {}, + "source": [ + "## Prerequisites for running on Colab\n", + "This notebook requires an NVIDIA GPU. If `nvidia-smi` fails, go to Runtime -> Change runtime type -> Hardware accelerator and confirm a GPU is selected." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "71c7a069", + "metadata": {}, + "outputs": [], + "source": [ + "!#nvidia-smi" + ] + }, + { + "cell_type": "markdown", + "id": "cf16785d", + "metadata": {}, + "source": [ + "If running on Colab, you will need to install the CUTLASS Python interface. 
To do so, uncomment the following line and run the cell:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c819bb68", + "metadata": {}, + "outputs": [], + "source": [ + "!#pip install nvidia-cutlass" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "962324fd", + "metadata": {}, + "source": [ + "## General setup\n", + "We first import various packages needed for the example and construct the input and output tensors that will be used in our example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0e324219", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "import random\n", + "\n", + "import cutlass\n", + "\n", + "# This controls whether the C++ GEMM declaration will be printed at each step. \n", + "# Set to `False` to omit this information.\n", + "print_module = True\n", + "\n", + "m = 128\n", + "n = m\n", + "k = m\n", + "\n", + "dtype = np.float16\n", + "type_A = np.float16\n", + "type_B = np.float16\n", + "type_C = np.float16\n", + "type_D = np.float16\n", + "\n", + "np.random.seed(1234)\n", + "random.seed(1234)\n", + "scope_min = -4\n", + "scope_max = 4\n", + "tensor_A = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, k)).astype(type_A))\n", + "tensor_B = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(k, n)).astype(type_B))\n", + "tensor_C = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, n)).astype(type_C))\n", + "\n", + "alpha = np.float16(1.)\n", + "beta = np.float16(0.)\n", + "\n", + "tensor_D = np.zeros(tensor_C.shape).astype(type_D)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "f2c7bf48", + "metadata": {}, + "source": [ + "## Declaring and running a GEMM\n", + "To get started, one only needs to provide the tensors declared above to the `cutlass.op.Gemm` call.\n", + "This sets up a default GEMM operation for the given device on which you are running.\n", + "\n", + "Assuming that we are running on SM80, this default to using a GEMM that leverages FP16 Tensor Core operations.\n", + "\n", + "Calling `plan.run()` will generate the CUTLASS C++ kernel in question, compile it, and run it on the tensors we previously passed in. By setting `print_module` to `true`, the C++ code that is emitted is printed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0dfd8975", + "metadata": {}, + "outputs": [], + "source": [ + "# We specify `element_accumulator` here so as to match the kernel run by NumPy below. However,\n", + "# specifying `element_accumulator` is not required if it is the same as `element`\n", + "plan = cutlass.Gemm(element=dtype, layout=cutlass.LayoutType.RowMajor, element_accumulator=np.float32)\n", + "plan.run(tensor_A, tensor_B, tensor_C, tensor_D, print_module=print_module)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "4a5856de", + "metadata": {}, + "source": [ + "There are many other ways to construct a plan from `cutlass.op.Gemm` (e.g., by specifiying they types and layouts of each operand, by providing representative tensors as inputs). For more details on these, see the documentation in the `cutlass.op.Gemm` constructor." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "945478ef", + "metadata": {}, + "source": [ + "We then compare the output to running the GEMM using NumPy." 
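As an aside on the constructor flexibility mentioned above, the sketch below illustrates what those alternative construction paths might look like. The per-operand keyword names (`element_A`, `layout_B`, ...) and the representative-tensor form (`A=...`, `B=...`) are assumptions drawn from the constructor documentation and from the FP8 example later in this notebook; the cell is illustrative only and is not needed for the comparison that follows.

```python
import numpy as np
import cutlass

# Reuses tensor_A/B/C/D from the setup cell above.

# Sketch 1 (assumed keywords): spell out each operand's type and layout explicitly.
plan_explicit = cutlass.op.Gemm(
    element_A=np.float16, element_B=np.float16,
    element_C=np.float16, element_D=np.float16,
    element_accumulator=np.float32,
    layout_A=cutlass.LayoutType.RowMajor,
    layout_B=cutlass.LayoutType.RowMajor,
    layout_C=cutlass.LayoutType.RowMajor)

# Sketch 2 (assumed keywords): let representative tensors imply the types and layouts.
plan_from_tensors = cutlass.op.Gemm(
    A=tensor_A, B=tensor_B, C=tensor_C, D=tensor_D,
    element_accumulator=np.float32)
```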
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6b669de6", + "metadata": {}, + "outputs": [], + "source": [ + "tensor_D_numpy = (alpha * (tensor_A @ tensor_B)) + (beta * tensor_C)\n", + "np.testing.assert_array_equal(tensor_D, tensor_D_numpy)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "ee5cbbbe", + "metadata": {}, + "source": [ + "Note that one could use the same kernel just declared for tensors provided by other frameworks beyond NumPy, such as PyTorch or CuPy." + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "b6c86493", + "metadata": {}, + "source": [ + "## Changing operation modes\n", + "By default, the CUTLASS Python interface will try to use Tensor Core operations whenever possible. If the configuration provided to `cutlass.op.Gemm` is not supported on Tensor Cores, the interface will fall back to using a SIMT kernel.\n", + "\n", + "The operation mode currently in use can be returned via the `plan.opclass` property. In this case Tensor Core operations." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "529fda93", + "metadata": {}, + "outputs": [], + "source": [ + "print(plan.opclass)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "6d27c575", + "metadata": {}, + "source": [ + "Suppose that we don't want to use Tensor Cores for this GEMM. One can change to using CUTLASS's SIMT GEMMs by setting the plan's `opclass` field.\n", + "\n", + "As is shown in the printed output, the emitted kernel uses template parameters that fit CUTLASS's SIMT GEMMs.\n", + "\n", + "Also notice that, this time around, we provided tensor parameters to `plan.run()`. One is free to provide different parameters to `plan.run()` than were passed in at the initial call to `cutlass.op.Gemm`, provided that the passed-in tensors have the same data type and layout as those passed in on intialization." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "6a44d35b", + "metadata": {}, + "outputs": [], + "source": [ + "tensor_D_simt = np.zeros(tensor_C.shape).astype(type_D)\n", + "plan.opclass = cutlass.OpcodeClass.Simt\n", + "plan.run(tensor_A, tensor_B, tensor_C, tensor_D_simt, alpha, beta, print_module=print_module)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "639dcb59", + "metadata": {}, + "source": [ + "If we compare the output of the Tensor Core and SIMT GEMMs we just ran we see that they are equal." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9b480853", + "metadata": {}, + "outputs": [], + "source": [ + "np.testing.assert_array_equal(tensor_D, tensor_D_simt)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "0cce1eae", + "metadata": {}, + "source": [ + "## Running cached kernels\n", + "You may have noticed that the `plan.run()` calls for the previous two kernels took some time to execute. This is because the kernel being emitted had not yet been compiled.\n", + "\n", + "CUTLASS caches compiled binaries so that recompilation isn't necessary every time a kernel is run. For example, if we change modes back to using Tensor Cores and call `plan.run()` again (with a different set of tensor parameters), you'll find the call to return much faster." 
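To make the effect of this cache visible, one can time the Python-level latency of `plan.run()` itself. The helper below is a minimal sketch (not part of the original notebook); it measures end-to-end call time, which includes compilation on the first call and only launch plus synchronization once the kernel is cached.

```python
import time

def timed_run(plan, *tensors, **kwargs):
    # Wall-clock latency of a single plan.run() call, including any JIT compilation.
    start = time.time()
    plan.run(*tensors, **kwargs)
    return time.time() - start

# Hypothetical usage with the tensors defined above:
# first  = timed_run(plan, tensor_A, tensor_B, tensor_C, tensor_D)   # triggers compilation
# second = timed_run(plan, tensor_A, tensor_B, tensor_C, tensor_D)   # served from the cache
# print(f"first call: {first:.2f} s, second call: {second:.2f} s")
```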
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "f8051e5e", + "metadata": {}, + "outputs": [], + "source": [ + "m = 2400\n", + "n = 3232\n", + "k = 4096\n", + "\n", + "tensor_A = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, k)).astype(type_A))\n", + "tensor_B = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(k, n)).astype(type_B))\n", + "tensor_C = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, n)).astype(type_C))\n", + "tensor_D = np.zeros(tensor_C.shape).astype(type_D)\n", + "\n", + "alpha = np.float16(1.)\n", + "beta = np.float16(2.)\n", + "\n", + "plan.opclass = cutlass.OpcodeClass.TensorOp\n", + "plan.run(tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, print_module=print_module)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "52a4e318", + "metadata": {}, + "source": [ + "## Running non-default GEMMs\n", + "The previous examples showed how it is simple to get started running a default GEMM kernel in CUTLASS. But, what do you do if you want a bit more control over the parameters to the GEMM?\n", + "\n", + "Under the hood, CUTLASS enumerates the different GEMM configuration parameters possible for this kernel from the CUTLASS profiler. The code below shows how one can access the tile descriptions for the kernels (e.g., cluster, threadblock, and warp shape)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1c593be1", + "metadata": {}, + "outputs": [], + "source": [ + "tiles = plan.tile_descriptions()\n", + "print('{} tile descriptions returned'.format(len(tiles)))\n", + "num_print = 10\n", + "print('First {} tile descriptions are:'.format(num_print))\n", + "for td in tiles[:num_print]:\n", + " print(td)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "dc3ad875", + "metadata": {}, + "source": [ + "Next, we'll pick one of these configurations at random and compile and run it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a8dc5287", + "metadata": {}, + "outputs": [], + "source": [ + "tiles = [td for td in tiles if td.threadblock_shape[0] >= 128]\n", + "idx = random.randint(0, len(tiles)-1)\n", + "td = tiles[idx]\n", + "print('Tile description {} is: {}'.format(idx, td))\n", + "plan.compile(td)\n", + "plan.run(tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, print_module=print_module)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "c5a8b534", + "metadata": {}, + "source": [ + "One can also change the swizzling function used by the kernel. 
For example, one can modify the kernel to use the stream K feature of CUTLASS via:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e5e88d17", + "metadata": {}, + "outputs": [], + "source": [ + "# Stream K is exposed through the threadblock swizzle method for pre-SM90 kernels,\n", + "# and via the tile_scheduler attribute of the TileDescription for post-SM90 kernels\n", + "if plan.cc < 90:\n", + " plan.swizzling_functor = cutlass.swizzle.ThreadblockSwizzleStreamK\n", + " plan.run(tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, print_module=print_module)\n", + "else:\n", + " # Stream-K is currently only supported for warp-specialized cooperative kernels\n", + " td.kernel_schedule = cutlass.KernelScheduleType.TmaWarpSpecializedCooperative\n", + " td.epilogue_schedule = cutlass.EpilogueScheduleType.TmaWarpSpecializedCooperative\n", + " td.tile_scheduler = cutlass.TileSchedulerType.StreamK\n", + "\n", + " plan.compile(td)\n", + " plan.run(tensor_A, tensor_B, tensor_C, tensor_D, alpha, beta, print_module=print_module)" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "5a8ba2ba", + "metadata": {}, + "source": [ + "## Handling errors\n", + "The CUTLASS Python interface attempts to catch runtime and compilation errors in Python so as to provide more understandable error messages.\n", + "\n", + "Here's an example in which we try to use too many stages for a given GEMM kernel. Normally, this would result in a runtime error due to the GPU having insufficient shared memory to launch the kernel with 8 stages. The CUTLASS Python interface is able to detect this issue before compiling the kernel, and reports it back to the user. Uncomment and run the code below to see this error." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fe7d0e42", + "metadata": {}, + "outputs": [], + "source": [ + "# td = tiles[0]\n", + "# td.stages = 8\n", + "# plan.compile(td)" + ] + }, + { + "cell_type": "markdown", + "id": "0fff34a4", + "metadata": {}, + "source": [ + "## Specializations for other data types\n", + "\n", + "Various CUTLASS kernels specialized for specific data types can also be run via the Python interface.\n", + "\n", + "For example, the code below shows how to declare and run a GEMM using the 3xTF32 feature (see corresponding C++ example [here](https://github.com/NVIDIA/cutlass/blob/main/examples/27_ampere_3xtf32_fast_accurate_tensorop_gemm/27_ampere_3xtf32_fast_accurate_tensorop_gemm.cu))." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "338ad890", + "metadata": {}, + "outputs": [], + "source": [ + "from cutlass.backend.utils.device import device_cc\n", + "\n", + "# 3xTF32 requires SM80 or higher\n", + "if device_cc() >= 80:\n", + " plan = cutlass.op.Gemm(element=np.float32, layout=cutlass.LayoutType.RowMajor)\n", + " plan.math_operation = cutlass.MathOperation.multiply_add_fast_f32\n", + "\n", + " # Create input/output tensors in FP32\n", + " A, B = [np.ones((128, 128)).astype(np.float32) for _ in range(2)]\n", + " C, D = [np.zeros((128, 128)).astype(np.float32) for _ in range(2)]\n", + "\n", + " # Run the GEMM\n", + " plan.run(A, B, C, D, print_module=print_module)" + ] + }, + { + "cell_type": "markdown", + "id": "65531df1", + "metadata": {}, + "source": [ + "Additionally, one can run CUTLASS's FP8 GEMMs if using a frontend library capable of allocating and initializing FP8 tensors (e.g., PyTorch)" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "776f1d8d", + "metadata": {}, + "outputs": [], + "source": [ + "try:\n", + " import torch\n", + "except ImportError:\n", + " print(\"PyTorch is not available. Skipping FP8 example\")\n", + " import sys; sys.exit(0)\n", + "\n", + "if not hasattr(torch, \"float8_e4m3fn\"):\n", + " print(\"Version of PyTorch does not have the float8_e4m3fn data type. Skipping FP8 example\")\n", + " import sys; sys.exit(0)\n", + "\n", + "# FP8 is supported through the CUTLASS Python interface on SM90 and higher\n", + "if device_cc() >= 90:\n", + " plan = cutlass.op.Gemm(element=torch.float8_e4m3fn, element_C=torch.float32, element_accumulator=torch.float32,\n", + " layout_A=cutlass.LayoutType.RowMajor, layout_B=cutlass.LayoutType.ColumnMajor,\n", + " layout_C=cutlass.LayoutType.ColumnMajor)\n", + "\n", + " # Create input/output tensors in FP8\n", + " A, B = [torch.ones((128, 128)).to(torch.float8_e4m3fn).to(\"cuda\") for _ in range(2)]\n", + " C, D = [torch.zeros((128, 128)).to(torch.float8_e4m3fn).to(\"cuda\") for _ in range(2)]\n", + "\n", + " # Run the GEMM\n", + " plan.run(A, B, C, D, print_module=print_module)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.6.9" + }, + "vscode": { + "interpreter": { + "hash": "0466d96796c9cd8f7a1cad264ff326ececc950ba2420e0256d5105fc1a3c6e70" + } + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/python/01_epilogue.ipynb b/examples/python/01_epilogue.ipynb new file mode 100644 index 0000000000..97663f5035 --- /dev/null +++ b/examples/python/01_epilogue.ipynb @@ -0,0 +1,253 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "5d24a692", + "metadata": {}, + "source": [ + "# Example of using elementwise activation functions in the CUTLASS Python interface\n", + "This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run GEMMs with different epilogues.\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/blob/main/examples/python/01_epilogue.ipynb)\n" + ] + }, + { + "cell_type": "markdown", + "id": "28c916da", + "metadata": {}, + "source": [ + "## 
Prerequisites for running on Colab\n", + "This notebook requires an NVIDIA GPU. If `nvidia-smi` fails, go to Runtime -> Change runtime type -> Hardware accelerator and confirm a GPU is selected." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "0fcea8ea", + "metadata": {}, + "outputs": [], + "source": [ + "!#nvidia-smi" + ] + }, + { + "cell_type": "markdown", + "id": "7ec60b57", + "metadata": {}, + "source": [ + "If running on Colab, you will need to install the CUTLASS Python interface. To do so, uncomment the following line and run the cell:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "1db9e51c", + "metadata": {}, + "outputs": [], + "source": [ + "!#pip install nvidia-cutlass" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "962324fd", + "metadata": {}, + "source": [ + "## General setup\n", + "We first import various packages needed for the example and construct the input and output tensors that will be used in our example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "63a70a3c", + "metadata": {}, + "outputs": [], + "source": [ + "import numpy as np\n", + "\n", + "import cutlass\n", + "\n", + "# This controls whether ther C++ GEMM declaration will be printed at each step. Set to `false` to\n", + "# omit this information.\n", + "print_module = True\n", + "\n", + "m = 256\n", + "n = m\n", + "k = m\n", + "\n", + "type_A = np.float16\n", + "type_B = np.float16\n", + "type_C = np.float16\n", + "type_D = np.float16\n", + "\n", + "np.random.seed(1234)\n", + "scope_min = -4\n", + "scope_max = 4\n", + "tensor_A = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, k)).astype(type_A))\n", + "tensor_B = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(k, n)).astype(type_B))\n", + "tensor_C = np.ceil(np.random.uniform(low=scope_min, high=scope_max, size=(m, n)).astype(type_C))\n", + "\n", + "alpha = np.float16(1.)\n", + "beta = np.float16(0.)\n", + "\n", + "tensor_D = np.zeros(tensor_C.shape).astype(type_D)" + ] + }, + { + "cell_type": "markdown", + "id": "1eb0d95b", + "metadata": {}, + "source": [ + "## Run a GEMM with an identity activation function\n", + "To begin, we simply run a default GEMM with an identity activation function. This performs the well-known operation `D = alpha * (A @ B) + beta * C`. This is the default activation function used, and does not need to be specified." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d257833", + "metadata": {}, + "outputs": [], + "source": [ + "plan = cutlass.op.Gemm(element=np.float16, layout=cutlass.LayoutType.RowMajor)\n", + "plan.run(tensor_A, tensor_B, tensor_C, tensor_D, print_module=print_module)" + ] + }, + { + "cell_type": "markdown", + "id": "54961694", + "metadata": {}, + "source": [ + "## Run a GEMM with a ReLU element-wise activation function\n", + "CUTLASS makes it easy to support other element-wise activation functions. This results in performing an element-wise after the generic linear combination performed in a GEMM. If we call such an activation function `act`, the resulting formulation is:\n", + "```\n", + "D = alpha * (A @ B) + beta * C\n", + "D = act(D)\n", + "```\n", + "\n", + "Here, we will add a ReLU activation function. Given an input `x`, ReLU returns `max(x, 0)`.\n", + "\n", + "This is easy to do in CUTLASS. One only needs to set the plan's `activation` field." 
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5fe49443", + "metadata": {}, + "outputs": [], + "source": [ + "tensor_D_relu = np.zeros(tensor_C.shape).astype(type_D)\n", + "plan.activation = \"relu\"\n", + "plan.run(tensor_A, tensor_B, tensor_C, tensor_D_relu, print_module=print_module)" + ] + }, + { + "cell_type": "markdown", + "id": "455d0a37", + "metadata": {}, + "source": [ + "We can now verify that the result of the GEMM that used a ReLU activation function:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e32e7798", + "metadata": {}, + "outputs": [], + "source": [ + "relu_ref = (tensor_D >= 0).astype(type_D) * tensor_D\n", + "np.testing.assert_array_equal(relu_ref, tensor_D_relu)" + ] + }, + { + "cell_type": "markdown", + "id": "cf959171", + "metadata": {}, + "source": [ + "## Other element-wise activation functions\n", + "CUTLASS supports a variety of widely-used element-wise activation functions. We can obtain a list of these functions via the `get_activations()` method." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9e17d730", + "metadata": {}, + "outputs": [], + "source": [ + "activations = plan.activations()\n", + "for activation in activations:\n", + " print(activation)" + ] + }, + { + "cell_type": "markdown", + "id": "0e4599fa", + "metadata": {}, + "source": [ + "We can then run each of them:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9c3598c9", + "metadata": {}, + "outputs": [], + "source": [ + "for activation in activations:\n", + " print('=============================================================================================')\n", + " print(f'Compiling and running activation {activation}')\n", + " print('=============================================================================================')\n", + " plan.activation = activation\n", + " plan.run(tensor_A, tensor_B, tensor_C, tensor_D, print_module=print_module)" + ] + }, + { + "cell_type": "markdown", + "id": "18828622", + "metadata": {}, + "source": [ + "To add an activation with parameter such as `leaky_relu`, a tuple should be provided containing the activation function name and the (or a list of) parameter." 
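For reference, leaky ReLU with negative slope `s` computes `x` for `x >= 0` and `s * x` otherwise. The NumPy sketch below is not part of the original notebook; it shows the unfused math performed by the fused kernel in the next cell, together with a hypothetical spot-check one could run afterwards.

```python
import numpy as np

def leaky_relu_reference(x, negative_slope):
    # Element-wise leaky ReLU: keep non-negative values, scale negative ones.
    return np.where(x >= 0, x, negative_slope * x).astype(x.dtype)

# Hypothetical check after the next cell has populated tensor_D:
# expected = leaky_relu_reference(alpha * (tensor_A @ tensor_B) + beta * tensor_C, negative_slope)
# np.testing.assert_array_equal(tensor_D, expected)
```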
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "53108eae", + "metadata": {}, + "outputs": [], + "source": [ + "negative_slope = 0.5\n", + "plan.activation = (\"leaky_relu\", negative_slope)\n", + "plan.run(tensor_A, tensor_B, tensor_C, tensor_D, print_module=print_module)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/python/02_pytorch_extension_grouped_gemm.ipynb b/examples/python/02_pytorch_extension_grouped_gemm.ipynb new file mode 100644 index 0000000000..86c86fb65c --- /dev/null +++ b/examples/python/02_pytorch_extension_grouped_gemm.ipynb @@ -0,0 +1,300 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "6acbea5d", + "metadata": {}, + "source": [ + "# Exporting a CUTLASS grouped GEMM kernel to a PyTorch CUDA extension\n", + "This notebook walks through a basic example of using the CUTLASS Python interface to declare\n", + "a grouped GEMM kernel and export it as a PyTorch CUDA extension. Note that GEMM and Conv2d can also be exported as PyTorch CUDA extensions. \n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/blob/main/examples/python/02_pytorch_extension_grouped_gemm.ipynb)\n" + ] + }, + { + "cell_type": "markdown", + "id": "2d70560e", + "metadata": {}, + "source": [ + "## Prerequisites for running on Colab\n", + "This notebook requires an NVIDIA GPU. If `nvidia-smi` fails, go to Runtime -> Change runtime type -> Hardware accelerator and confirm a GPU is selected." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cc7c7458", + "metadata": {}, + "outputs": [], + "source": [ + "!#nvidia-smi" + ] + }, + { + "cell_type": "markdown", + "id": "2107bb0d", + "metadata": {}, + "source": [ + "If running on Colab, you will need to install the CUTLASS Python interface and PyTorch. To do so, uncomment the following line and run the cell:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "a9852cb8", + "metadata": {}, + "outputs": [], + "source": [ + "!#pip install nvidia-cutlass torch --extra-index-url https://download.pytorch.org/whl/cu121" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "962324fd", + "metadata": {}, + "source": [ + "## Background on grouped GEMM\n", + "Grouped GEMM enables one to execute a set of GEMMs (each with potentially different sizes and strides)\n", + "in a single CUDA kernel. 
It can be thought of as a generalized version of a pointer-array GEMM,\n", + "without the requirement that the sizes and strides of each GEMM be the same.\n", + "\n", + "For example, if one has `p` GEMMs with sizes:\n", + "```text\n", + "M_1 x N_1 x K_1\n", + "M_2 x N_2 x K_2\n", + "...\n", + "M_p x N_p x K_p\n", + "```\n", + "CUTLASS's grouped GEMM will execute these in a single CUDA kernel.\n", + "\n", + "Grouped GEMM is particularly beneficial for saturating the GPU with many small problems that would\n", + "insufficiently utilize the device in isolation.\n", + "\n", + "## Declaring a grouped GEMM via the CUTLASS Python interface\n", + "A grouped GEMM operation is declared similarly to a GEMM operation in the CUTLASS Python interface: one\n", + "simply calls `cutlass.op.GroupedGemm`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "fdcf21d8", + "metadata": {}, + "outputs": [], + "source": [ + "import cutlass\n", + "import torch\n", + "\n", + "dtype = torch.float16\n", + "plan = cutlass.op.GroupedGemm(element=dtype, layout=cutlass.LayoutType.RowMajor)" + ] + }, + { + "cell_type": "markdown", + "id": "514f40a4", + "metadata": {}, + "source": [ + "We can then compile and run this operation on a group of GEMMs. We'll first set up some utility functions to initialize GEMMs." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "c2a7371e", + "metadata": {}, + "outputs": [], + "source": [ + "import random\n", + "random.seed(2023)\n", + "\n", + "# Utility function to initialize A, B, C, and D matrices corresponding to dimensions M, N, and K\n", + "def initialize(dtype, M, N, K):\n", + " sizes = [(M, K), (K, N), (M, N), (M, N)]\n", + " return [torch.randint(-3, 3, size, device='cuda').to(dtype) for size in sizes]\n", + "\n", + "# Utility function to generate `problems` GEMMs of random sizes\n", + "def generate_problems(problems):\n", + " valid_sizes = [128, 256, 512, 1024]\n", + " As, Bs, Cs, Ds = [], [], [], []\n", + " for _ in range(problems):\n", + " M, N, K = [random.choice(valid_sizes) for _ in range(3)]\n", + " A, B, C, D = initialize(dtype, M, N, K)\n", + " As.append(A)\n", + " Bs.append(B)\n", + " Cs.append(C)\n", + " Ds.append(D)\n", + " return As, Bs, Cs, Ds" + ] + }, + { + "cell_type": "markdown", + "id": "590a3bc5", + "metadata": {}, + "source": [ + "We'll next run a group of 20 GEMMs via the CUTLASS Python interface and via PyTorch." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "776c9233", + "metadata": {}, + "outputs": [], + "source": [ + "As, Bs, Cs, Ds, = generate_problems(20)\n", + "\n", + "plan.run(As, Bs, Cs, Ds, print_module=True)\n", + "Ds_torch = [a @ b for a, b in zip(As, Bs)]\n", + "\n", + "for d, d_torch in zip(Ds, Ds_torch):\n", + " assert torch.allclose(d, d_torch)" + ] + }, + { + "cell_type": "markdown", + "id": "766e4f03", + "metadata": {}, + "source": [ + "## Exporting the CUTLASS kernel to a PyTorch CUDA extension\n", + "The procedure above allows one to quickly experiment with using a CUTLASS kernels However, one might prefer to use the CUTLASS kernel via a [PyTorch CUDA extension](https://pytorch.org/tutorials/advanced/cpp_extension.html). This will avoids adding any runtime overheads associated with the Python portions of the CUTLASS Python interface.\n", + "\n", + "The CUTLASS Python interface provides simple solutions for creating PyTorch CUDA extensions for a CUTLASS kernel. 
These extensions can either be written out for a later \"ahead-of-time\" compilation, or be just-in-time compiled and returned to the user.\n", + "\n", + "To create a JIT-compiled module from the CUTLASS kernel we defined above, simply call the following:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "3a98dee6", + "metadata": {}, + "outputs": [], + "source": [ + "op = plan.construct()\n", + "grouped_gemm = cutlass.emit.pytorch(op, name='grouped_gemm', cc=plan.cc, sourcedir='out', jit=True)" + ] + }, + { + "cell_type": "markdown", + "id": "c8ca3991", + "metadata": {}, + "source": [ + "The `cutlass.emit.pytorch` function emits:\n", + "* `out/grouped_gemm_kernel.cu`: This file contains the declaration of the CUTLASS kernel and a method to call it from PyTorch tensors\n", + "* `out/grouped_gemm.cpp`: This file contains a C++ wrapper around the aforementioned CUTLASS kernel\n", + "* `setup.py`: This file contains the `setuptools` script for building and installing the generated extension\n", + "\n", + "The extension can be build from within the `module_output` directory by running:\n", + "```bash\n", + "TORCH_CUDA_ARCH_LIST=\"8.0\" python setup.py install\n", + "```\n", + "Where `TORCH_ARCH_LIST` is set to the compute capability of the device on which the kernel will be run.\n", + "\n", + "See the PyTorch [\"Custom C++ and CUDA Extensions\"](https://pytorch.org/tutorials/advanced/cpp_extension.html) tutorial for more details on this.\n", + "\n", + "The PyTorch CUDA extension could be built for this module by running:\n", + "```bash\n", + "cd out\n", + "TORCH_CUDA_ARCH_LIST=\"8.0\" python setup.py\n", + "```\n", + "(assuming that one is building for SM80)\n", + "\n", + "One could then use the kernel in a later PyTorch module by running:\n", + "\n", + "```python\n", + "import torch\n", + "import grouped_gemm\n", + "\n", + "grouped_gemm.run(As, Bs)\n", + "```\n", + "\n", + "In this case, however, we set `jit=True`, which specifies that we would like to compile and load the PyTorch CUDA extension on the fly.\n", + "Under the hood, this leverages the [torch.utils.cpp_extension.load](https://pytorch.org/tutorials/advanced/cpp_extension.html) method\n", + "and returns back the loaded extension.\n", + "\n", + "We can then use the extension and compare its results to running the GEMMs via vanilla PyTorch GEMMs:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "cecb26a4", + "metadata": {}, + "outputs": [], + "source": [ + "Ds = grouped_gemm.run(As, Bs)\n", + "Ds_torch = [a @ b for a, b in zip(As, Bs)]\n", + "for d, d_torch in zip(Ds, Ds_torch):\n", + " assert torch.allclose(d, d_torch)" + ] + }, + { + "cell_type": "markdown", + "id": "50db80e4", + "metadata": {}, + "source": [ + "Finally, we can profile our grouped GEMM extension:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "b76805d3", + "metadata": {}, + "outputs": [], + "source": [ + "num_warmup = 20\n", + "num_profile = 100\n", + "\n", + "# Warmup iterations\n", + "for _ in range(num_warmup):\n", + " Ds = grouped_gemm.run(As, Bs)\n", + " Ds_torch = [a @ b for a, b in zip(As, Bs)]\n", + " torch.cuda.synchronize()\n", + "\n", + "# Timing iterations\n", + "import time\n", + "grouped = 0\n", + "nongrouped = 0\n", + "for _ in range(num_profile):\n", + " start = time.time()\n", + " Ds = grouped_gemm.run(As, Bs)\n", + " torch.cuda.synchronize()\n", + " grouped += time.time() - start\n", + "\n", + " start = time.time()\n", + " Ds_torch = [a @ b for a, b in zip(As, Bs)]\n", + " 
torch.cuda.synchronize()\n", + " nongrouped += time.time() - start\n", + "\n", + "print('Grouped: {:.3f} us'.format(grouped * 1e6/num_profile))\n", + "print('Non-Grouped: {:.3f} us'.format(nongrouped * 1e6/num_profile))\n", + "print('Speedup: {:.3f}'.format(nongrouped / grouped))" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/python/03_basic_conv2d.ipynb b/examples/python/03_basic_conv2d.ipynb new file mode 100644 index 0000000000..d0eb452675 --- /dev/null +++ b/examples/python/03_basic_conv2d.ipynb @@ -0,0 +1,465 @@ +{ + "cells": [ + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "# Basic example of using the CUTLASS Python interface for Conv2d\n", + "\n", + "This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run Conv2d. \n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/blob/main/examples/python/03_basic_conv2d.ipynb)\n" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Prerequisites for running on Colab\n", + "This notebook requires an NVIDIA GPU. If `nvidia-smi` fails, go to Runtime -> Change runtime type -> Hardware accelerator and confirm a GPU is selected." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!#nvidia-smi" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If running on Colab, you will need to install the CUTLASS Python interface. To do so, uncomment the following line and run the cell:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "!#pip install nvidia-cutlass" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## General setup\n", + "We first import various packages needed for the example and construct the input and output tensors that will be used in our example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import random\n", + "\n", + "import cutlass\n", + "\n", + "# This controls whether the C++ GEMM declaration will be printed at each step. 
\n", + "# Set to `false` to omit this information.\n", + "print_module = True\n", + "\n", + "# Input tensor: [N, H, W, C] under the channel-last layout\n", + "N, H, W, C = [32, 28, 28, 64]\n", + "\n", + "# Weight tensor: [K, R, S, C] under the channel-last layout\n", + "K, R, S = [128, 3, 3]\n", + "\n", + "# Stride, and padding\n", + "stride = (2, 2)\n", + "padding = (1, 1)\n", + "dilation = (1, 1)\n", + "\n", + "# Compute the output size [N, P, Q, K]\n", + "N, P, Q, K = cutlass.Conv2d.output_size((N, H, W, C), (K, R, S, C), padding, stride, dilation)\n", + "\n", + "dtype = torch.float16\n", + "type_A = torch.float16\n", + "type_B = torch.float16\n", + "type_C = torch.float16\n", + "type_D = torch.float16\n", + "\n", + "torch.manual_seed(1234)\n", + "\n", + "input = torch.ceil(\n", + " torch.empty(size=(N, C, H, W), dtype=type_A, device=\"cuda\").uniform_(-4.5, 3.5)\n", + ").to(memory_format=torch.channels_last)\n", + "weight = torch.ceil(\n", + " torch.empty(size=(K, C, R, S), dtype=type_B, device=\"cuda\").uniform_(-4.5, 3.5)\n", + ").to(memory_format=torch.channels_last)\n", + "tensor_C = torch.ceil(\n", + " torch.empty(size=(N, K, P, Q), dtype=type_B, device=\"cuda\").uniform_(-4.5, 3.5)\n", + ").to(memory_format=torch.channels_last)\n", + "output = torch.zeros_like(tensor_C)\n", + "\n", + "alpha = 1.0\n", + "beta = 0.0" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Declaring and running a Conv2d Fprop\n", + "\n", + "We first show you how to run a Conv2d in the forward propagation. To get started, one only needs to provide the tensors declared above to the `cutlass.op.Conv2dFprop` call. This sets up a default Conv2d fprop operation for the given device on which you are running. \n", + "\n", + "Assuming that we are runing on SM80, the default is a Conv2d that leverages FP16 Tensor Core operations.\n", + "\n", + "Calling `plan.run()` will generate the CUTLASS C++ kernel in question, compile it, and run it on the tensors we previously passed in. By setting `print_module` to `true`, the C++ code that is emitted is printed." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Specifying `element_accumulator` is not required if it is the same as `element`\n", + "plan = cutlass.Conv2dFprop(element=dtype, element_accumulator=torch.float32)\n", + "plan.run(input, weight, tensor_C, output, stride, padding, dilation, alpha, beta, print_module=print_module)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "There are many other ways to construct a plan from `cutlass.op.Conv2dFprop` (e.g., by specifying the types of each operand, by providing representative tensors as input). For more details on these, see the documentation in the `cutlass.op.Conv2dFprop` constructor.\n", + "\n", + "We then compare the output to running the Conv2d using PyTorch. PyTorch use NCHW layout by default, so permutations are required." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "output_torch = alpha * torch.ops.aten.conv2d(\n", + " input, weight, stride=stride, padding=padding, dilation=dilation\n", + ") + beta * tensor_C\n", + "\n", + "assert torch.equal(output_torch, output)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Note that one could use the same kernel just declared for tensors provided by other frameworks beyond PyTorch, such as NumPy." 
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Declaring and running Conv2d Dgrad and Wgrad\n", + "\n", + "The Python interface also supports declaring and running backward kernels of Conv2d. To begin with, we construct the tensors for the gradient of input, output, and weight." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "grad_output = torch.ceil(\n", + " torch.empty(size=(N, K, P, Q), dtype=type_A, device=\"cuda\").uniform_(-4.5, 3.5)\n", + ").to(memory_format=torch.channels_last)\n", + "grad_input = torch.zeros_like(input)\n", + "grad_weight = torch.zeros_like(weight)\n", + "\n", + "tensor_C_dgrad = torch.ceil(\n", + " torch.empty(size=(N, C, H, W), dtype=type_A, device=\"cuda\").uniform_(-4.5, 3.5)\n", + ").to(memory_format=torch.channels_last)\n", + "tensor_C_wgrad = torch.ceil(\n", + " torch.empty(size=(K, C, R, S), dtype=type_B, device=\"cuda\").uniform_(-4.5, 3.5)\n", + ").to(memory_format=torch.channels_last)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The script below gives a simple example of computing a data gradient via the CUTLASS Python interface and via PyTorch." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plan_dgrad = cutlass.Conv2dDgrad(element=dtype, element_accumulator=torch.float32)\n", + "plan_dgrad.run(grad_output, weight, tensor_C_dgrad, grad_input, stride, padding, dilation, alpha, beta, print_module=print_module)\n", + "\n", + "grad_input_torch = alpha * torch.nn.grad.conv2d_input(\n", + " (N, C, H, W),\n", + " weight, grad_output,\n", + " stride=stride, padding=padding\n", + ") + beta * tensor_C_dgrad\n", + "\n", + "assert torch.equal(grad_input_torch, grad_input)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "The script below gives a simple example of computing a weight gradient via the CUTLASS Python interface and via PyTorch." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plan_wgrad = cutlass.Conv2dWgrad(element=dtype, element_accumulator=torch.float32)\n", + "plan_wgrad.run(grad_output, input, tensor_C_wgrad, grad_weight, stride, padding, dilation, alpha, beta, print_module=print_module)\n", + "\n", + "grad_weight_torch = alpha * torch.nn.grad.conv2d_weight(\n", + " input, (K, C, R, S), grad_output,\n", + " stride=stride, padding=padding\n", + ") + beta * tensor_C_wgrad\n", + "\n", + "assert torch.equal(grad_weight_torch, grad_weight)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Running non-default Conv2ds\n", + "\n", + "The previous examples showed how it is simple to get starting running a default Conv2d kernel in CUTLASS. But, what do you do if you want a bit more control over the parameters to the Conv2d? CUTLASS Python interface exposes mutable parameters that can be set after the `plan` initialization. 
We summarize these in the table below.\n", + "\n", + "|Parameter|Description|\n", + "| -- | -- |\n", + "|`tile_description`|The threadblock tile size, warp count, software pipeline stages, and instruction shape|\n", + "|`iterator_algorithm`|The iterator algorithm used to access the source operands|\n", + "|`swizzling_stride`|The stride of the threadblock swizzling functor|\n", + "|`split-K`|Partitions the reduction dimension to different threadblocks|" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Tile Description\n", + "\n", + "The `tile_description` defines the tiling size of each threadblock, the warp count along each dimension of the tile, the software pipeline stages, and the instruction size. Under the hood, CUTLASS enumerates the different Conv2d configuration parameters for this kernel from the CUTLASS profiler. The code below shows how one can access the tile descriptions for the kernel (e.g., threadblock and warp shape)." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plan.opclass = \"tensor_op\"\n", + "tiles = plan.tile_descriptions()\n", + "print(f'{len(tiles)} tile descriptions returned')\n", + "num_print = 10\n", + "print(f'First {num_print} tile descriptions are:')\n", + "for td in tiles[:num_print]:\n", + " print(td)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Next, we'll pick one of these configurations at random and compile and run it." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "random.seed(42)\n", + "idx = random.randint(0, len(tiles)-1)\n", + "td = tiles[idx]\n", + "print(f'Tile description {idx} is: {td}')\n", + "plan.tile_description = td\n", + "plan.run(input, weight, tensor_C, output, stride, padding, dilation, alpha, beta, print_module=print_module)\n", + "assert torch.equal(output_torch, output)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Besides tile descriptions enumerated by CUTLASS, the users can also explicitly set the `threadblockshape`, `warp_shape`, `stages`, `instruction_shape`, and `cluster_shape`. If the configuration is invalid, an exception will be raised at `plan.run()` and the detailed compilation error will be stored in `./cutlass_python_compilation_error.txt` for debugging." 
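Because an invalid hand-written configuration only surfaces when `plan.run()` compiles the kernel, it can be convenient to wrap the call while experimenting. The helper below is a minimal sketch (not part of the original notebook); the exception type is not specified in this text, so it deliberately catches broadly and points at the compilation log mentioned above.

```python
def try_tile_description(plan, tile_description, *run_args, **run_kwargs):
    # Apply a candidate tile description and report whether it compiles and runs.
    plan.tile_description = tile_description
    try:
        plan.run(*run_args, **run_kwargs)
        return True
    except Exception as exc:  # broad on purpose; the exact exception type is not documented here
        print(f"Configuration rejected: {exc}")
        print("See ./cutlass_python_compilation_error.txt for the full compiler output.")
        return False
```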
+ ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "if plan.cc == 70:\n", + " plan.tile_description = {\n", + " \"threadblock_shape\": [64, 256, 32],\n", + " \"warp_count\": [1, 4, 1],\n", + " \"stages\": 2,\n", + " \"instruction_shape\": [8, 8, 4], # optional,\n", + " \"cluster_shape\": [1, 1, 1] # optional, only [1, 1, 1] is supported currently\n", + " }\n", + "elif plan.cc == 75:\n", + " plan.tile_description = {\n", + " \"threadblock_shape\": [128, 64, 32],\n", + " \"warp_count\": [2, 1, 1],\n", + " \"stages\": 2,\n", + " \"instruction_shape\": [16, 8, 8], # optional,\n", + " \"cluster_shape\": [1, 1, 1] # optional, only [1, 1, 1] is supported currently\n", + " }\n", + "elif plan.cc == 80:\n", + " plan.tile_description = {\n", + " \"threadblock_shape\": [128, 128, 64],\n", + " \"warp_count\": [2, 2, 1],\n", + " \"stages\": 4,\n", + " \"instruction_shape\": [16, 8, 16], # optional,\n", + " \"cluster_shape\": [1, 1, 1] # optional, only [1, 1, 1] is supported currently\n", + " }\n", + "elif plan.cc == 86:\n", + " plan.tile_description = {\n", + " \"threadblock_shape\": [128, 64, 64],\n", + " \"warp_count\": [2, 2, 1],\n", + " \"stages\": 3,\n", + " \"instruction_shape\": [16, 8, 16],\n", + " \"cluster_shape\": [1, 1, 1]\n", + " }\n", + "\n", + "plan.run(input, weight, tensor_C, output, stride, padding, dilation, alpha, beta, print_module=print_module)\n", + "assert torch.equal(output_torch, output)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Iterator Algorithm\n", + "\n", + "The iterator algorithm describes how sources are loaded from memory. There are some iterator algorithms optimized for specific alignments and input/output channels that have better performance. The table below illustrates the available iterator algorithms.\n", + "\n", + "|Conv Kind | Iterator Algorithm | Description |\n", + "| -- | -- | -- |\n", + "|Fprop | \"analytic\" | Functionally correct in all cases but lower performance |\n", + "| | \"optimized\" | Optimized for and requires `R <= 32`, `S<= 32`, and `C % alignment_input == 0`|\n", + "| | \"few_channels\" | optimized for small `C` and requires `C % alignment_input == 0`|\n", + "| | \"fixed_channels\" | optimized for small `C` and requires `C == alignment_input` |\n", + "|Dgrad | \"analytic\" | Functionally correct in all cases but lower performance |\n", + "| | \"optimized\" | Optimzed for and require `R <= 32`, `S<= 32`, `K % alignment_grad_output == 0`, and `C % alignment_weight == 0`|\n", + "|Wgrad | \"analytic\" | Functionally correct in all cases but lower performance |\n", + "| | \"optimized\" | Optimized for and require `K % alignment_grad_output == 0`, and `C % alignment_input == 0`|\n", + "\n", + "By default, the Python interface will automatically propose a suitable iterator algorithm based on the input tensors in `plan.run()`. However, the user can also specify the desired iterator algorithm as follows" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plan.iterator_algorithm = \"analytic\"\n", + "plan.run(input, weight, tensor_C, output, stride, padding, dilation, alpha, beta, print_module=print_module)\n", + "assert torch.equal(output_torch, output)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "If the iterator algorithm is invalid for the problem size in `plan.run()`, an exception will be raised." 
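The selection rules in the table above can also be written down programmatically. The function below is purely illustrative: it is not the heuristic the interface actually uses, and the `small_channel_threshold` cutoff is an assumption standing in for the table's "small `C`"; it only restates the fprop constraints listed above.

```python
def propose_fprop_iterator_algorithm(C, R, S, alignment_input, small_channel_threshold=32):
    # Illustrative restatement of the fprop rows in the table above.
    if C == alignment_input:
        return "fixed_channels"      # optimized for very small C
    if C % alignment_input != 0:
        return "analytic"            # the only variant with no alignment requirement
    if C <= small_channel_threshold:
        return "few_channels"        # optimized for small C
    if R <= 32 and S <= 32:
        return "optimized"
    return "analytic"

# For the tensors above (C = 64, R = S = 3) with an assumed 8-element alignment:
# propose_fprop_iterator_algorithm(64, 3, 3, 8)  ->  "optimized"
```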
+ ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Swizzling Stride\n", + "The swizzling changes how the tile are mapped to threadblocks to improve the L2 Locality. Given a swizzling stride `N`, the threadblock `(tb_x, tb_y)` computes tile `(tb_x / N, tb_y * N + (tb_x % N))`. Currently, stride values of `1`, `2`, `4`, and `8` are supported for `fprop`, `wgrad`, and `1`, and `4` for `dgrad`. The swizzling stride can be set with:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "plan.swizzling_stride = 4\n", + "plan.run(input, weight, tensor_C, output, stride, padding, dilation, alpha, beta, print_module=print_module)\n", + "assert torch.equal(output_torch, output)" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Split-K\n", + "Split-K is usually applied when the Conv2d has small spatial dimensions and large reduction dimension to ensure good utilization. It further partitions the reduction dimension to different threadblocks. The CUTLASS Python interface supports two types of split-K strategies: `Parallel`, and `Serial`. \n", + "* `Parallel`: the partial results from different threadblocks are stored in a temporary buffer in the global memory. When the Conv2d is done, a separate reduction kernel is created and launched to reduce the partial results.\n", + "* `Serial`: A semaphore is used to coordinate the order of different threadblocks adding their partial results to a given output tile. A separate kernel does not need to be launched for prforming the reduction.\n", + "\n", + "While all `fprop`, `dgrad`, and `wgrad` support split-K, here we use `wgrad` as an example. " + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Parallel Split-K with 5 slices\n", + "grad_weight_parallel = torch.zeros_like(grad_weight)\n", + "plan_wgrad.run(\n", + " grad_output, input, tensor_C_wgrad, grad_weight_parallel, \n", + " stride, padding, dilation, alpha, beta, print_module=print_module, split_k=(\"parallel\", 5))\n", + "assert torch.equal(grad_weight_torch, grad_weight_parallel)\n", + "\n", + "# Serial Split-K with 3 slices\n", + "grad_weight_serial = torch.zeros_like(grad_weight)\n", + "plan_wgrad.run(\n", + " grad_output, input, tensor_C_wgrad, grad_weight_serial, \n", + " stride, padding, dilation, alpha, beta, print_module=print_module, split_k=(\"serial\", 3))\n", + "assert torch.equal(grad_weight_torch, grad_weight_serial)" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.10.9" + } + }, + "nbformat": 4, + "nbformat_minor": 4 +} diff --git a/examples/python/04_epilogue_visitor.ipynb b/examples/python/04_epilogue_visitor.ipynb new file mode 100644 index 0000000000..cf66cd2414 --- /dev/null +++ b/examples/python/04_epilogue_visitor.ipynb @@ -0,0 +1,258 @@ +{ + "cells": [ + { + "attachments": {}, + "cell_type": "markdown", + "id": "5d24a692", + "metadata": {}, + "source": [ + "# Example of using epilogue visitor in the CUTLASS Python interface\n", + "This notebook walks through a basic example of using the CUTLASS Python interface to declare, compile, and run GEMMs with different epilogues 
through CUTLASS Epilogue Visitor.\n", + "\n", + "[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/github/NVIDIA/cutlass/blob/main/examples/python/04_epilogue_visitor.ipynb)\n" + ] + }, + { + "cell_type": "markdown", + "id": "3a800e79", + "metadata": {}, + "source": [ + "## Prerequisites for running on Colab\n", + "This notebook requires an NVIDIA GPU. If `nvidia-smi` fails, go to Runtime -> Change runtime type -> Hardware accelerator and confirm a GPU is selected." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "9cfff2c8", + "metadata": {}, + "outputs": [], + "source": [ + "!#nvidia-smi" + ] + }, + { + "cell_type": "markdown", + "id": "06706f00", + "metadata": {}, + "source": [ + "If running on Colab, you will need to install the CUTLASS Python interface. To do so, uncomment the following line and run the cell:" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "491a7314", + "metadata": {}, + "outputs": [], + "source": [ + "!#pip install nvidia-cutlass" + ] + }, + { + "attachments": {}, + "cell_type": "markdown", + "id": "962324fd", + "metadata": {}, + "source": [ + "## General setup\n", + "We first import various packages needed for the example, construct the input and output tensors that will be used in our example." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "63a70a3c", + "metadata": {}, + "outputs": [], + "source": [ + "import torch\n", + "import cutlass\n", + "from cutlass.epilogue import relu\n", + "from cutlass import Tensor as FakeTensor\n", + "from cutlass.utils.profiler import CUDAEventProfiler\n", + "\n", + "# This controls whether ther C++ GEMM declaration will be printed at each step. Set to `false` to\n", + "# omit this information.\n", + "print_module = True\n", + "\n", + "# The Epilogue Visitor feature currently only works for SM80 and 90\n", + "from cutlass.backend.utils.device import device_cc\n", + "if device_cc() not in [80, 90]:\n", + " import sys\n", + " sys.exit()\n", + "\n", + "m = 16384\n", + "n = m\n", + "k = 512\n", + "\n", + "type_A = torch.float16\n", + "type_B = torch.float16\n", + "type_C = torch.float16\n", + "type_D = torch.float16\n", + "\n", + "torch.manual_seed(2023)\n", + "scope_min = -4\n", + "scope_max = 4\n", + "tensor_A = torch.ceil(torch.empty(size=(m, k), dtype=type_A, device=\"cuda\").uniform_(scope_min, scope_max))\n", + "tensor_B = torch.ceil(torch.empty(size=(k, n), dtype=type_B, device=\"cuda\").uniform_(scope_min, scope_max))\n", + "tensor_C = torch.ceil(torch.empty(size=(m, n), dtype=type_C, device=\"cuda\").uniform_(scope_min, scope_max))\n", + "tensor_D = torch.zeros_like(tensor_C)\n", + "\n", + "plan = cutlass.op.Gemm(element=torch.float16, layout=cutlass.LayoutType.RowMajor, element_accumulator=torch.float32)" + ] + }, + { + "cell_type": "markdown", + "id": "1eb0d95b", + "metadata": {}, + "source": [ + "## Define the epilogue visitor functor\n", + "The epilogue functor can be defined as a simple Python function and a set of example tensors for inputs and outputs. The example below illustrates a complex epilogue under the directed acyclic graph structure (`F` is used twice). The epilogue takes source tensors in different ranks: `alpha`, `beta` are scalars, `bias` is a column vector to broadcast, and `C`, `aux` are matrices. It contains various math operations from basic arithmatic operations and built-in callable functions like `relu`. It also accomodates multiple outputs `D` and `F`. 
Note that there are some restrictions on syntax.\n", + "* Each named variable must be assigned exactly once and defined before it used.\n", + "* Reserved names: `accum`, `C`, and `D` are reserved for accumulator, tensor_C, and tensor_D.\n", + "* Return values must be a named variable.\n", + "\n", + "The example tensors is a dictionary with tensor names as keys and reference tensors as values. The reference tensors can be `float`, `torch.Tensor`, `numpy.ndarray`, or our `FakeTensor`. They provides the shape and data type information of the inputs and outputs of the epilogue.\n", + "\n", + "The epilogue can be generated simply through `cutlass.evt.trace(, )`." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8d257833", + "metadata": {}, + "outputs": [], + "source": [ + "# Define epilogue visitor\n", + "def example_epilogue(accum, alpha, C, beta, aux, bias):\n", + " F = alpha * accum + (beta * C + aux)\n", + " E = relu(F + 1) + bias\n", + " D = E + F\n", + " return D, F\n", + "\n", + "# Construct inputs and outputs\n", + "alpha = 0.5\n", + "beta = 0.5\n", + "aux = torch.ceil(torch.empty(size=(m, n), dtype=type_C, device=\"cuda\").uniform_(scope_min, scope_max))\n", + "bias = torch.ceil(torch.empty(size=(m, 1), dtype=type_C, device=\"cuda\").uniform_(scope_min, scope_max))\n", + "tensor_F = torch.zeros_like(tensor_D)\n", + "examples_tensors = {\n", + " \"accum\": FakeTensor(element=torch.float32, shape=(m, n), layout_tag=cutlass.LayoutType.RowMajor),\n", + " \"alpha\": alpha,\n", + " \"C\": tensor_C,\n", + " \"beta\": beta,\n", + " \"aux\": aux,\n", + " \"bias\": bias,\n", + " \"D\": tensor_D,\n", + " \"F\": tensor_F\n", + "}\n", + "\n", + "# Trace the epilogue visitor\n", + "epilogue_visitor = cutlass.epilogue.trace(example_epilogue, examples_tensors)" + ] + }, + { + "cell_type": "markdown", + "id": "54961694", + "metadata": {}, + "source": [ + "## Run a GEMM with the epilogue visitor functor\n", + "The `epilogue_visitor` can be used by setting the plan's `epilogue_visitor` field. The arguments for the epilogue visitor are provided as a `dict` through the `visitor_args` keyword argument." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "5fe49443", + "metadata": {}, + "outputs": [], + "source": [ + "visitor_args = {\n", + " \"alpha\": alpha, \"C\": tensor_C, \"beta\": beta, \n", + " \"aux\": aux, \"bias\": bias, \"D\": tensor_D, \"F\": tensor_F\n", + "}\n", + "\n", + "plan.epilogue_visitor = epilogue_visitor\n", + "plan.run(\n", + " tensor_A, tensor_B, tensor_C, tensor_D, \n", + " visitor_args=visitor_args, print_module=print_module)" + ] + }, + { + "cell_type": "markdown", + "id": "455d0a37", + "metadata": {}, + "source": [ + "The epilogue function `example_epilogue` can be used as a reference function. 
We can now verify the results simply with" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "e32e7798", + "metadata": {}, + "outputs": [], + "source": [ + "class TorchReference(torch.nn.Module):\n", + " def forward(self, A, B, alpha, C, beta, aux, bias):\n", + " accum = torch.matmul(A, B)\n", + " return example_epilogue(accum, alpha, C, beta, aux, bias)\n", + "\n", + "torch_reference = TorchReference()\n", + "tensor_D_ref, tensor_F_ref = torch_reference(tensor_A, tensor_B, alpha, tensor_C, beta, aux, bias)\n", + "\n", + "assert torch.equal(tensor_D, tensor_D_ref)\n", + "assert torch.equal(tensor_F, tensor_F_ref)" + ] + }, + { + "cell_type": "markdown", + "id": "b69e441f", + "metadata": {}, + "source": [ + "The performance of CUTLASS fused kernel can be profiled with" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "8db92150", + "metadata": {}, + "outputs": [], + "source": [ + "warmup_iterations = 10\n", + "profile_iterations = 50\n", + "# Profile CUTLASS fused kernel\n", + "duration = CUDAEventProfiler(\n", + " plan, warmup_iterations, profile_iterations,\n", + " tensor_A, tensor_B, tensor_C, tensor_D, \n", + " visitor_args=visitor_args)()\n", + "\n", + "print(f\"CUTLASS duration: {duration:.2f} ms\")" + ] + } + ], + "metadata": { + "kernelspec": { + "display_name": "Python 3 (ipykernel)", + "language": "python", + "name": "python3" + }, + "language_info": { + "codemirror_mode": { + "name": "ipython", + "version": 3 + }, + "file_extension": ".py", + "mimetype": "text/x-python", + "name": "python", + "nbconvert_exporter": "python", + "pygments_lexer": "ipython3", + "version": "3.8.10" + } + }, + "nbformat": 4, + "nbformat_minor": 5 +} diff --git a/examples/python/README.md b/examples/python/README.md new file mode 100644 index 0000000000..590f2e24e4 --- /dev/null +++ b/examples/python/README.md @@ -0,0 +1,22 @@ +# Examples of using the CUTLASS Python interface + +* [00_basic_gemm](/examples/python/00_basic_gemm.ipynb) + + Shows how declare, configure, compile, and run a CUTLASS GEMM using the Python interface + +* [01_epilogue](/examples/python/01_epilogue.ipynb) + + Shows how to fuse elementwise activation functions to GEMMs via the Python interface + +* [02_pytorch_extension_grouped_gemm](/examples/python/02_pytorch_extension_grouped_gemm.ipynb) + + Shows how to declare, compile, and run a grouped GEMM operation via the Python interface, + along with how the emitted kernel can be easily exported to a PyTorch CUDA extension. + +* [03_basic_conv2d](/examples/python/03_basic_conv2d.ipynb) + + Shows how to declare, configure, compile, and run a CUTLASS Conv2d using the Python interface + +* [04_epilogue_visitor](/examples/python/04_epilogue_visitor.ipynb) + + Shows how to fuse elementwise activation functions to GEMMs via the Python Epilogue Visitor interface diff --git a/include/cute/algorithm/axpby.hpp b/include/cute/algorithm/axpby.hpp new file mode 100644 index 0000000000..339743f491 --- /dev/null +++ b/include/cute/algorithm/axpby.hpp @@ -0,0 +1,95 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +// +// Accept mutable temporaries +// +template +CUTE_HOST_DEVICE +void +axpby(Alpha const& alpha, + Tensor const& x, + Beta const& beta, + Tensor && y, + PrdTensor const& p = {}) +{ + return axpby(alpha, x, beta, y, p); +} + +// +// AXPBY +// +template +CUTE_HOST_DEVICE +void +axpby(Alpha const& alpha, + Tensor const& x, + Beta const& beta, + Tensor & y, + PrdTensor const& p = {}) +{ + auto isBetaZero = [&] () { + if constexpr (is_complex::value) { + return beta.real() == Int<0>{} && beta.imag() == Int<0>{}; + } + else { + return beta == Int<0>{}; + } + + CUTE_GCC_UNREACHABLE; + } (); + + CUTE_UNROLL + for (int i = 0; i < size(x); ++i) { + if (p(i)) { + y(i) = (isBetaZero ? alpha * x(i) : alpha * x(i) + beta * y(i)); + } + } +} + +} // end namespace cute diff --git a/include/cute/algorithm/clear.hpp b/include/cute/algorithm/clear.hpp new file mode 100644 index 0000000000..0b3a8eaa1d --- /dev/null +++ b/include/cute/algorithm/clear.hpp @@ -0,0 +1,64 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::Tensor +#include // cute::fill + +namespace cute +{ + +// +// Accept mutable temporaries +// +template +CUTE_HOST_DEVICE +void +clear(Tensor&& tensor) +{ + return clear(tensor); +} + +// +// Set elements to zero +// +template +CUTE_HOST_DEVICE +void +clear(Tensor& tensor) +{ + using T = typename Tensor::value_type; + + fill(tensor, T{}); +} + +} // end namespace cute diff --git a/include/cute/algorithm/cooperative_copy.hpp b/include/cute/algorithm/cooperative_copy.hpp new file mode 100644 index 0000000000..c9e02245e2 --- /dev/null +++ b/include/cute/algorithm/cooperative_copy.hpp @@ -0,0 +1,339 @@ +/*************************************************************************************************** +* Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. +* SPDX-License-Identifier: BSD-3-Clause +* +* Redistribution and use in source and binary forms, with or without +* modification, are permitted provided that the following conditions are met: +* +* 1. Redistributions of source code must retain the above copyright notice, this +* list of conditions and the following disclaimer. +* +* 2. Redistributions in binary form must reproduce the above copyright notice, +* this list of conditions and the following disclaimer in the documentation +* and/or other materials provided with the distribution. +* +* 3. Neither the name of the copyright holder nor the names of its +* contributors may be used to endorse or promote products derived from +* this software without specific prior written permission. +* +* THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" +* AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE +* IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE +* DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE +* FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL +* DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR +* SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER +* CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, +* OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE +* OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+* +**************************************************************************************************/ +#pragma once + +#include +#include +#include // cute::logical_divide +#include // cute::Swizzle +#include // cute::get_nonswizzle_portion +#include // cute::Tensor +#include +#include +#include + +namespace cute +{ + +template +CUTE_HOST_DEVICE void +naive_cooperative_copy(uint32_t const& tid, + Tensor const& src, + Tensor & dst) +{ + auto N = size(dst); + auto R = N % Int{}; + if (R > 0 && tid < R) { // Likely static condition && Residue in-bounds + dst[tid] = src[tid]; + } + CUTE_UNROLL + for (uint32_t i = uint32_t(R); i < uint32_t(N); i += NumThreads) { // All in-bounds + dst[tid + i] = src[tid + i]; + } +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE void +naive_cooperative_copy(uint32_t const& tid, + Tensor const& src, + Tensor && dst) +{ + return naive_cooperative_copy(tid, src, dst); +} + +// A heuristic to determine a "good" permutation of two tensors for later vectorization and thr-assignment +template +CUTE_HOST_DEVICE constexpr +auto +heuristic_permutation(Tensor const& a, + Tensor const& b) +{ + constexpr bool swizzleA = get_swizzle_t::num_bits != 0 or + get_swizzle_t::num_bits != 0; + constexpr bool swizzleB = get_swizzle_t::num_bits != 0 or + get_swizzle_t::num_bits != 0; + auto a_inv = right_inverse(get_nonswizzle_portion(a.layout())); + auto b_inv = right_inverse(get_nonswizzle_portion(b.layout())); + + constexpr uint8_t scoreA = (uint8_t(swizzleA) << 2) | + (uint8_t(is_smem::value) << 1) | + (uint8_t(size(a_inv) > size(b_inv)) << 0); + + constexpr uint8_t scoreB = (uint8_t(swizzleB) << 2) | + (uint8_t(is_smem::value) << 1) | + (uint8_t(size(b_inv) > size(a_inv)) << 0); + + if constexpr (scoreA >= scoreB) { + return a_inv; + } else { + return b_inv; + } +} + +// cooperative_copy(thr_idx, src, dst) +// Use NumThreads to copy Tensor src to Tensor dst with element-wise vectorization up to MaxVecBits. +// @pre 0 <= @a tid < NumThreads +// @pre Tensors @a src and @a dst are aligned up to MaxVecBits. +// That is, pointers and dynamic strides are assumed to be aligned up to MaxVecBits. 
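To make the contract described in the comment above concrete, a block-wide staging copy might look like the following minimal sketch. This is not part of the header; the kernel, tile shape, and pointer names are illustrative, and the global pointer is assumed to be 16-byte aligned.

```cpp
#include <cute/tensor.hpp>                      // cute::make_tensor, cute::make_smem_ptr, ...
#include <cute/algorithm/cooperative_copy.hpp>  // cute::cooperative_copy

// Hypothetical kernel: 128 threads cooperatively stage a 128x8 float tile into shared memory.
__global__ void stage_tile(float const* gptr)   // gptr assumed 16-byte aligned
{
  using namespace cute;
  __shared__ alignas(16) float smem[128 * 8];

  // Statically shaped, compact (column-major) views of the same 128x8 tile.
  Tensor gA = make_tensor(make_gmem_ptr(gptr), make_shape(Int<128>{}, Int<8>{}));
  Tensor sA = make_tensor(make_smem_ptr(smem), make_shape(Int<128>{}, Int<8>{}));

  // NumThreads = 128, MaxVecBits = 128: each participating thread issues up to 128-bit accesses.
  cooperative_copy<128, 128>(threadIdx.x, gA, sA);
  __syncthreads();  // make the staged tile visible to the whole thread block
}
```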
+// +template +CUTE_HOST_DEVICE +void +cooperative_copy(uint32_t const& tid, + Tensor const& src, + Tensor & dst, + CopyPolicy const& cpy = {}) +{ + // Assumes the shapes are static, can generalize/fallback + CUTE_STATIC_ASSERT_V(is_static{} && is_static{}); + CUTE_STATIC_ASSERT_V(size(src) == size(dst)); + // Assumes the types are the same, can generalize/fallback + static_assert(cute::is_same::value); + static_assert(MaxVecBits == sizeof_bits_v || + MaxVecBits == 8 || MaxVecBits == 16 || MaxVecBits == 32 || MaxVecBits == 64 || MaxVecBits == 128, + "Expected MaxVecBits to be value size or 8 or 16 or 32 or 64 or 128 for alignment and performance."); + // Check that the tensors are likely shared across threads: either gmem or smem + static_assert((is_gmem::value || is_smem::value), + "cooperative_copy expects shared gmem or smem source tensor."); + static_assert((is_gmem::value || is_smem::value), + "cooperative_copy expects shared gmem or smem destination tensor."); + // Precondition on tid in DEBUG + assert(tid < NumThreads); + // Precondition on pointer alignment in DEBUG + assert(is_byte_aligned(raw_pointer_cast(src.data()))); + assert(is_byte_aligned(raw_pointer_cast(dst.data()))); + +#if 0 + if (thread0()) { + print(" "); print("cooperative_copy\n"); + print(" "); print("NumThreads: "); print(NumThreads); print("\n"); + print(" "); print("MaxVecBits: "); print(MaxVecBits); print("\n"); + print(" "); print("src: "); print(src); print("\n"); + print(" "); print("dst: "); print(dst); print("\n"); + } +#ifdef __CUDA_ARCH__ + __syncthreads(); +#endif +#endif + + // The common layout of the two tensors that can be vectorized over elements and threads + // vidx -> coord + auto common_layout = heuristic_permutation(src, dst); + + // Apply + // (V, rest) + Tensor src_a = coalesce(logical_divide(src, common_layout), Shape<_1,_1>{}); + Tensor dst_a = coalesce(logical_divide(dst, common_layout), Shape<_1,_1>{}); + + // + // Determine vectorization of elems and thrs based on src/dst size and number of threads + // NOTE: This heuristic promotes parallelization over vectorization + // + + // The number of elements and number of bits + constexpr int elem_bits = sizeof_bits_v; + constexpr int total_elem = size(SrcLayout{}); + + // The number of elements that can be vectorized in values + constexpr int common_elem = decltype(max_common_vector(src_a, dst_a))::value; + +#if 0 + if (thread0()) { + print(" "); print("common_layout: "); print(common_layout); print("\n"); + print(" "); print("src_a: "); print(src_a); print("\n"); + print(" "); print("dst_a: "); print(dst_a); print("\n"); + } +#ifdef __CUDA_ARCH__ + __syncthreads(); +#endif +#endif + + // + if constexpr (total_elem % NumThreads != 0) { + // Not attempting to find a partitioning pattern, fallback to dynamically indexed slowpath + + if constexpr (common_elem > 1 && MaxVecBits > elem_bits) { + // If the vectorization is non-trivial and divides the maximum vectorizations, then vectorize + constexpr auto max_align_src = elem_bits * decltype(max_alignment(src_a.layout()))::value; + constexpr auto max_align_dst = elem_bits * decltype(max_alignment(dst_a.layout()))::value; + constexpr auto vec_bits = gcd(max_align_src, max_align_dst, MaxVecBits); + using VecType = uint_bit_t; + + static_assert(vec_bits % elem_bits == 0, "Expected divisibility"); + static_assert((vec_bits >= 8), "No support for subbyte copying"); + + Tensor src_v = recast(src_a); + Tensor dst_v = recast(dst_a); + +#if 0 + if (thread0()) { + print(" "); print("cooperative_copy -- 
naive\n"); + print(" "); print("src_v: "); print(src_v); print("\n"); + print(" "); print("dst_v: "); print(dst_v); print("\n"); + } +#ifdef __CUDA_ARCH__ + __syncthreads(); +#endif +#endif + + naive_cooperative_copy(tid, src_v, dst_v); + } else { + naive_cooperative_copy(tid, src_a, dst_a); + } + } else { + // If the tensors can be equally partitioned by the threads, + // compute vectorization widths in elements and threads. + + // If there are too many threads to allow a full vectorized copy, trunc the vectorization + constexpr int total_bits = total_elem * elem_bits; + constexpr int max_bits_per_thr = total_bits / NumThreads; + // At least elem_bits, at most common_bits + constexpr int common_bits = common_elem * elem_bits; + constexpr int vec_bits = cute::max(elem_bits, cute::gcd(common_bits, int(MaxVecBits), max_bits_per_thr)); + + // Should account for vec_bits < 8 and/or vec_elem <= 1 + // And also account for subbyte types, which could cause race conditions + // Want to ENFORCE sufficient vectorization in those cases + static_assert(vec_bits % elem_bits == 0, "Expected divisibility"); + static_assert(vec_bits >= 8, "No support for subbyte copying"); + + using VecType = uint_bit_t; + constexpr int vec_elem = vec_bits / elem_bits; + + constexpr int vec_thrs = cute::min(int(NumThreads), total_elem / vec_elem); + + // + // Determine the partitioning patterns for the vec_elems and vec_thrs + // + + // Distribute the rest of the V*T to some consistent portion outside of the common_layout, if needed + auto common_domain_src = domain_distribute(shape(src_a), Int{}); + auto common_domain_dst = domain_distribute(shape(dst_a), Int{}); + + // Make sure for now, could fall back here instead + CUTE_STATIC_ASSERT_V(size(common_domain_src) == Int{}); + CUTE_STATIC_ASSERT_V(compatible(common_domain_src, common_domain_dst) || + compatible(common_domain_dst, common_domain_src)); + // Use the "more specific" domain for the extra elements of V*T + auto common_domain = conditional_return(compatible(common_domain_src, common_domain_dst), + common_domain_dst, common_domain_src); + + // Construct the tiler + auto tiler_vt = common_domain.with_shape(Int{}, Int{}); + + // Apply and slice + Tensor src_v = logical_divide(src_a, tiler_vt)(make_coord(_,tid),_); + Tensor dst_v = logical_divide(dst_a, tiler_vt)(make_coord(_,tid),_); + +#if 0 + if (thread0()) { + print(" "); print("cooperative_copy -- vec\n"); + print(" "); print("Used vector: "); print(vec_elem); print("\n"); + print(" "); print("Used threads: "); print(vec_thrs); print("\n"); + print(" "); print("tiler_vt: "); print(tiler_vt); print("\n"); + print(" "); print("src_v: "); print(src_v); print("\n"); + print(" "); print("dst_v: "); print(dst_v); print("\n"); + print(" "); print("recast(src_v): "); print(recast(src_v)); print("\n"); + print(" "); print("recast(dst_v): "); print(recast(dst_v)); print("\n"); + } +#ifdef __CUDA_ARCH__ + __syncthreads(); +#endif +#endif + + // If we're using all threads (static) or the tid is in-range (dynamic) + if (vec_thrs == NumThreads or tid < vec_thrs) { + auto src_c = recast(src_v); + auto dst_c = recast(dst_v); + return copy(cpy, src_c, dst_c); + } + } +} + + +// Default max-vectorization size to value_type size +template +CUTE_HOST_DEVICE +void +cooperative_copy(uint32_t const& tid, + Tensor const& src, + Tensor & dst, + CopyPolicy const& cpy = {}) +{ + constexpr uint32_t MaxVecBits = sizeof_bits_v; + return cooperative_copy(tid, src, dst, cpy); +} + +// +// Accept mutable temporaries +// + +template 
+CUTE_HOST_DEVICE +void +cooperative_copy(uint32_t const& tid, + Tensor const& src, + Tensor && dst, + CopyPolicy const& cpy = {}) +{ + return cooperative_copy(tid, src, dst, cpy); +} + +template +CUTE_HOST_DEVICE +void +cooperative_copy(uint32_t const& tid, + Tensor const& src, + Tensor && dst, + CopyPolicy const& cpy = {}) +{ + return cooperative_copy(tid, src, dst, cpy); +} + +} // end namespace cute diff --git a/include/cute/algorithm/cooperative_gemm.hpp b/include/cute/algorithm/cooperative_gemm.hpp new file mode 100644 index 0000000000..e4bd5ea628 --- /dev/null +++ b/include/cute/algorithm/cooperative_gemm.hpp @@ -0,0 +1,585 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include + +#include +#include +#include + +#include + +namespace cute +{ + +// +// Cooperative Shared-Memory GEMMs +// + +namespace detail { + +// Slow fallback path: +template +CUTE_HOST_DEVICE +void +epilogue_predication(ThrMMA const& thr_mma, + Alpha const& alpha, + Tensor & tCrC, + Beta const& beta, + Tensor & sC, + Tensor & tCsC, + CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op) // transforms results before they are stored to C +{ + using InputTypeC = typename TSC::value_type; + using ComputeTypeC = typename ThrMMA::ValTypeC; + CUTE_STATIC_ASSERT(CUTE_STL_NAMESPACE::is_same_v); + + // Create coordinate tensors for the problem + Tensor cC = make_identity_tensor(shape(sC)); // (M,N) -> (m,n) + // Repeat partitioning with thr_mma + Tensor tCcC = thr_mma.partition_C(cC); // (MMA,MMA_M,MMA_N) -> (m,n) + + const bool isBetaZero = [&] () { + if constexpr (is_complex::value) { + return beta.real() == Int<0>{} && beta.imag() == Int<0>{}; + } + else { + return beta == Int<0>{}; + } + CUTE_GCC_UNREACHABLE; + } (); + + // Custom axpby_if for now + CUTE_UNROLL + for (int i = 0; i < size(tCrC); ++i) + { + if (elem_less(tCcC(i), shape(sC))) + { + tCsC(i) = sC_store_op(isBetaZero ? alpha * tCrC(i) + : alpha * tCrC(i) + + beta * static_cast(sC_load_op(tCsC(i)))); + } + } +} + +template +CUTE_HOST_DEVICE +void +epilogue_no_predication(Alpha const& alpha, + Tensor & tCrC, + Beta const& beta, + Tensor & tCsC, + CLoadTransformOp const& sC_load_op, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op, // transforms results before they are stored to C + SmemCopyOpC const& sC_copy_op) +{ + using InputTypeC = typename TSC::value_type; + using ComputeTypeC = typename TRC::value_type; + + const bool isBetaZero = [&] () { + if constexpr (is_complex::value) { + return beta.real() == Int<0>{} && beta.imag() == Int<0>{}; + } + else { + return beta == Int<0>{}; + } + CUTE_GCC_UNREACHABLE; + } (); + + Tensor tCrDi = make_fragment_like(tCsC); + Tensor tCrD = make_fragment_like(tCrC); + if(!isBetaZero) { + copy(sC_copy_op, tCsC, tCrDi); + // Transform C on/after load + cute::transform(tCrDi, tCrD, sC_load_op); + } + // C = alpha * (A * B) + beta * C + axpby(alpha, tCrC, beta, tCrD); + // Transform C before/on store + cute::transform(tCrD, tCrDi, sC_store_op); + copy(sC_copy_op, tCrDi, tCsC); +} + +// Predicated Cooperative GEMM +template +CUTE_HOST_DEVICE +void +cooperative_gemm_predication(ThrMMA const& thr_mma, + Tensor const& sA, + Tensor const& sB, + Tensor & tCrC, + ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op) // transforms B values before use in GEMM +{ + using InputTypeA = typename TA::value_type; + using InputTypeB = typename TB::value_type; + using InputTypeC = typename TC::value_type; + using ComputeTypeA = typename ThrMMA::ValTypeA; + using ComputeTypeB = typename ThrMMA::ValTypeB; + using ComputeTypeC = typename ThrMMA::ValTypeC; + + // + // MMA Partitioning + // + + // Partition the sA, sB, and sC tiles across the threads for the MMA + Tensor tCsA = thr_mma.partition_A(sA); // (MMA,MMA_M,MMA_K) + Tensor tCsB = thr_mma.partition_B(sB); // (MMA,MMA_N,MMA_K) + + // Create register tensors for the MMA to operate on + Tensor tCrA = thr_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K) + Tensor tCrB = 
thr_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K) + +#if 0 + if (thread0()) { + print(" sA: "); print( sA); print("\n"); + print(" sB: "); print( sB); print("\n"); + print(thr_mma); + print("tCsA: "); print(tCsA); print("\n"); + print("tCsB: "); print(tCsB); print("\n"); + print("tCrA: "); print(tCrA); print("\n"); + print("tCrB: "); print(tCrB); print("\n"); + print("tCrC: "); print(tCrC); print("\n"); + } +#endif + + // + // PREDICATION + // + + // Create coordinate tensors for the problem + Tensor cA = make_identity_tensor(shape(sA)); // (M,K) -> (m,k) + Tensor cB = make_identity_tensor(shape(sB)); // (N,K) -> (n,k) + + // Repeat partitioning with thr_mma + Tensor tCcA = thr_mma.partition_A(cA); // (MMA,MMA_M,MMA_K) -> (m,k) + Tensor tCcB = thr_mma.partition_B(cB); // (MMA,MMA_N,MMA_K) -> (n,k) + + // Allocate the preds for MMA- and MMA_MN-modes + Tensor tCpA = make_tensor(make_shape(size<0>(tCsA), size<1>(tCsA))); + Tensor tCpB = make_tensor(make_shape(size<0>(tCsB), size<1>(tCsB))); + + // Populate the predicates on M and N + CUTE_UNROLL + for (int i = 0; i < size(tCpA); ++i) { + tCpA(i) = elem_less(get<0>(tCcA(_,_,Int<0>{})(i)), shape<0>(sA)); + } + CUTE_UNROLL + for (int i = 0; i < size(tCpB); ++i) { + tCpB(i) = elem_less(get<0>(tCcB(_,_,Int<0>{})(i)), shape<0>(sB)); + } + +#if 0 + if (thread0()) { + print(" cA: "); print( cA); print("\n"); + print(" cB: "); print( cB); print("\n"); + print("tCcA: "); print(tCcA); print("\n"); + print("tCcB: "); print(tCcB); print("\n"); + print_tensor(tCpA); + print_tensor(tCpB); + } +#endif + + // + // PREFETCH k_block = 0 + // Condition the k-predication on (static) k_block == K_BLOCK_MAX-1, the last k_block + // Assumes the MMA-tiling in K is trivial + // + + constexpr int K_BLOCK_MAX = size<2>(tCrA); + + CUTE_UNROLL + for (int m = 0; m < size<1>(tCrA); ++m) { // Copy MMA_M + CUTE_UNROLL + for (int i = 0; i < size<0>(tCrA); ++i) { // Copy MMA_I + tCrA(i,m,0) = (tCpA(i,m) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,0)), shape<1>(sA)))) ? static_cast(sA_load_op(tCsA(i,m,0))) : ComputeTypeA{}; + } + } + CUTE_UNROLL + for (int n = 0; n < size<1>(tCrB); ++n) { // Copy MMA_N + CUTE_UNROLL + for (int i = 0; i < size<0>(tCrB); ++i) { // Copy MMA_I + tCrB(i,n,0) = (tCpB(i,n) && (0 < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,0)), shape<1>(sB)))) ? static_cast(sB_load_op(tCsB(i,n,0))) : ComputeTypeB{}; + } + } + // + // MAINLOOP + // + + CUTE_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) + { + if (k_block < K_BLOCK_MAX-1) // static-if not the last k_block + { + int k_next = k_block + 1; // Load k_next block + + // Condition the k-predication on (static) k_block == K_BLOCK_MAX-1, the last k_block + // Assumes the MMA-tiling in K is trivial + + CUTE_UNROLL + for (int m = 0; m < size<1>(tCrA); ++m) { // Copy MMA_M + CUTE_UNROLL + for (int i = 0; i < size<0>(tCrA); ++i) { // Copy MMA_I + tCrA(i,m,k_next) = (tCpA(i,m) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcA(i,m,k_next)), shape<1>(sA)))) ? static_cast(sA_load_op(tCsA(i,m,k_next))) : ComputeTypeA{}; + } + } + CUTE_UNROLL + for (int n = 0; n < size<1>(tCrB); ++n) { // Copy MMA_N + CUTE_UNROLL + for (int i = 0; i < size<0>(tCrB); ++i) { // Copy MMA_I + tCrB(i,n,k_next) = (tCpB(i,n) && (k_next < K_BLOCK_MAX-1 || elem_less(get<1>(tCcB(i,n,k_next)), shape<1>(sB)))) ? 
static_cast(sB_load_op(tCsB(i,n,k_next))) : ComputeTypeB{}; + } + } + } + // GEMM on k_block in registers + gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + } +} + +// Unpredicated Cooperative GEMM +template +CUTE_HOST_DEVICE +void +cooperative_gemm_no_predication(uint32_t thread_idx, + ThrMMA const& thr_mma, + Tensor const& sA, + Tensor const& sB, + Tensor & tCrC, + ALoadTransformOp const& sA_load_op, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op, // transforms B values before use in GEMM + SmemCopyOpA const& sA_copy_op, + SmemCopyOpB const& sB_copy_op) +{ + using InputTypeA = typename TA::value_type; + using InputTypeB = typename TB::value_type; + using InputTypeC = typename TC::value_type; + using ComputeTypeA = typename ThrMMA::ValTypeA; + using ComputeTypeB = typename ThrMMA::ValTypeB; + using ComputeTypeC = typename ThrMMA::ValTypeC; + + + // + // MMA Partitioning + // + + // Create register tensors for the MMA to operate on + Tensor tCrA = thr_mma.partition_fragment_A(sA); // (MMA,MMA_M,MMA_K) + Tensor tCrB = thr_mma.partition_fragment_B(sB); // (MMA,MMA_N,MMA_K) + + using CopyOpAType = SmemCopyOpA; + using CopyOpBType = SmemCopyOpB; + + auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom{}, thr_mma); + auto smem_thr_copy_A = smem_tiled_copy_A.get_thread_slice(thread_idx); + Tensor tCsA = smem_thr_copy_A.partition_S(sA); + Tensor tCrAi = make_fragment_like(tCsA); + Tensor tCrAi_copy_view = smem_thr_copy_A.retile_D(tCrAi); + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(tCrAi_copy_view)); // CPY_M + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCrAi_copy_view)); // CPY_K + + auto smem_tiled_copy_B = make_tiled_copy_B(Copy_Atom{}, thr_mma); + auto smem_thr_copy_B = smem_tiled_copy_B.get_thread_slice(thread_idx); + Tensor tCsB = smem_thr_copy_B.partition_S(sB); + Tensor tCrBi = make_fragment_like(tCsB); + Tensor tCrBi_copy_view = smem_thr_copy_B.retile_D(tCrBi); + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<1>(tCrBi_copy_view)); // CPY_N + CUTE_STATIC_ASSERT_V(size<2>(tCsB) == size<2>(tCrBi_copy_view)); // CPY_K + +#if 0 + if (thread0()) { + print(" sA: "); print(sA); print("\n"); + print(" sB: "); print(sB); print("\n"); + print(thr_mma); print("\n"); + print("tCrA: "); print(tCrA); print("\n"); + print("tCrB: "); print(tCrB); print("\n"); + print("tCrC: "); print(tCrC); print("\n"); + print(smem_thr_copy_A); print("\n"); + print("tCsA: "); print(tCsA); print("\n"); + print("tCrA_copy_view: "); print(tCrA_copy_view); print("\n"); + print(smem_thr_copy_B); print("\n"); + print("tCsB: "); print(tCsB); print("\n"); + print("tCrB_copy_view: "); print(tCrB_copy_view); print("\n"); + } +#endif + + // + // PREFETCH + // + + copy(smem_tiled_copy_A, tCsA(_,_,Int<0>{}), tCrAi_copy_view(_,_,Int<0>{})); + copy(smem_tiled_copy_B, tCsB(_,_,Int<0>{}), tCrBi_copy_view(_,_,Int<0>{})); + // + // MAINLOOP + // + + constexpr int K_BLOCK_MAX = size<2>(tCrA); + + CUTE_UNROLL + for (int k_block = 0; k_block < K_BLOCK_MAX; ++k_block) + { + // static-if load the next k_block. No k-predication required on these loads. 
+ if (k_block < K_BLOCK_MAX-1) + { + // Load the next k_block + int k_next = k_block + 1; // statically unrolled + copy(smem_tiled_copy_A, tCsA(_,_,k_next), tCrAi_copy_view(_,_,k_next)); + copy(smem_tiled_copy_B, tCsB(_,_,k_next), tCrBi_copy_view(_,_,k_next)); + } + + // Transform A and B, relying on the compiler to remove in case of identity ops + cute::transform(tCrAi(_,_,k_block), tCrA(_,_,k_block), sA_load_op); + cute::transform(tCrBi(_,_,k_block), tCrB(_,_,k_block), sB_load_op); + + // GEMM on k_block in registers + gemm(thr_mma, tCrA(_,_,k_block), tCrB(_,_,k_block), tCrC); + } +} + +} // end namespace detail + +// C passed as a shared memory tensor +// Epilogue included +template +CUTE_HOST_DEVICE +void +cooperative_gemm(uint32_t thread_idx, + TiledMMA const& tiled_mma, + Alpha const& alpha, + Tensor const& sA, + Tensor const& sB, + Beta const& beta, + Tensor & sC, + ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C + SmemCopyOpA const& sA_copy_op = {}, + SmemCopyOpB const& sB_copy_op = {}, + SmemCopyOpC const& sC_copy_op = {}) +{ + CUTE_STATIC_ASSERT_V(rank(sA) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(sB) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(sC) == Int<2>{}); + + CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM + CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN + CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK + + using InputTypeA = typename TA::value_type; + using InputTypeB = typename TB::value_type; + using InputTypeC = typename TC::value_type; + using ComputeTypeA = typename TiledMMA::ValTypeA; + using ComputeTypeB = typename TiledMMA::ValTypeB; + using ComputeTypeC = typename TiledMMA::ValTypeC; + + auto compat = evenly_divides(make_shape(size<0>(sA), size<0>(sB), size<1>(sA)), + tile_shape(TiledMMA{})); + + // ThrMMA + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCsC = thr_mma.partition_C(sC); // (MMA,MMA_M,MMA_N) :: InputTypeC + Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N) :: ComputeTypeC + + // Clear accumulators + clear(tCrC); + +#if 0 + if (thread0()) { + print(" sC: "); print(sC); print("\n"); + print(" tCsC: "); print(tCsC); print("\n"); + } +#endif + + if constexpr (is_constant::value) { + detail::cooperative_gemm_no_predication( + thread_idx, thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op, sA_copy_op, sB_copy_op + ); + detail::epilogue_no_predication( + alpha, tCrC, beta, tCsC, sC_load_op, sC_store_op, sC_copy_op + ); + } else { + detail::cooperative_gemm_predication( + thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op + ); + detail::epilogue_predication( + thr_mma, alpha, tCrC, beta, sC, tCsC, sC_load_op, sC_store_op + ); + } +} + +// C already partitioned into registers on input +// It can be passed non-empty +// Epilogue not included +template +CUTE_HOST_DEVICE +void +cooperative_gemm(uint32_t thread_idx, + TiledMMA const& tiled_mma, + Tensor const& sA, + Tensor const& sB, + Tensor & tCrC, + ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM + SmemCopyOpA const& sA_copy_op = {}, + SmemCopyOpB const& sB_copy_op = {}) +{ + CUTE_STATIC_ASSERT_V(rank(sA) == Int<2>{}); + 
CUTE_STATIC_ASSERT_V(rank(sB) == Int<2>{}); + + CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK + + using InputTypeA = typename TA::value_type; + using InputTypeB = typename TB::value_type; + using InputTypeC = typename TC::value_type; + using ComputeTypeA = typename TiledMMA::ValTypeA; + using ComputeTypeB = typename TiledMMA::ValTypeB; + using ComputeTypeC = typename TiledMMA::ValTypeC; + + // Check if input C fragment is compatible with thr_mma and problem size + using ref_c_frag = decltype(partition_shape_C(tiled_mma, make_shape(size<0>(sA), size<0>(sB)))); + CUTE_STATIC_ASSERT_V(compatible(shape(ref_c_frag{}), shape(tCrC))); + + auto compat = evenly_divides(make_shape(size<0>(sA), size<0>(sB), size<1>(sA)), + tile_shape(TiledMMA{})); + + // ThrMMA + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + + if constexpr (is_constant::value) { + detail::cooperative_gemm_no_predication( + thread_idx, thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op, sA_copy_op, sB_copy_op + ); + } else { + detail::cooperative_gemm_predication( + thr_mma, sA, sB, tCrC, sA_load_op, sB_load_op + ); + } +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE +void +cooperative_gemm(uint32_t thread_idx, + TiledMMA const& tiled_mma, + Alpha const& alpha, + Tensor const& sA, + Tensor const& sB, + Beta const& beta, + Tensor && sC, + ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op = {}, // transforms results before they are stored to C + SmemCopyOpA const& sA_copy_op = {}, + SmemCopyOpB const& sB_copy_op = {}, + SmemCopyOpC const& sC_copy_op = {}) +{ + cooperative_gemm(thread_idx, tiled_mma, alpha, sA, sB, beta, sC, + sA_load_op, sB_load_op, sC_load_op, sC_store_op, + sA_copy_op, sB_copy_op, sC_copy_op); +} + +// Legacy overload of cute::gemm for backwards-compatibility +template +CUTE_HOST_DEVICE +void +gemm(ThrMMA const& thr_mma, + Alpha const& alpha, + Tensor const& sA, + Tensor const& sB, + Beta const& beta, + Tensor & sC, + ALoadTransformOp const& sA_load_op = {}, // transforms A values before use in GEMM + BLoadTransformOp const& sB_load_op = {}, // transforms B values before use in GEMM + CLoadTransformOp const& sC_load_op = {}, // transforms C values before use in GEMM + CStoreTransformOp const& sC_store_op = {}) // transforms results before they are stored to C +{ + CUTE_STATIC_ASSERT_V(rank(sA) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(sB) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(sC) == Int<2>{}); + + CUTE_STATIC_ASSERT_V(size<0>(sA) == size<0>(sC)); // AM == CM + CUTE_STATIC_ASSERT_V(size<0>(sB) == size<1>(sC)); // BN == CN + CUTE_STATIC_ASSERT_V(size<1>(sA) == size<1>(sB)); // AK == BK + + Tensor tCsC = thr_mma.partition_C(sC); // (MMA,MMA_M,MMA_N) + Tensor tCrC = thr_mma.make_fragment_C(tCsC); // (MMA,MMA_M,MMA_N) + + // Goes directly to the slow path to avoid getting thread_idx from thr_mma + detail::cooperative_gemm_predication( + thr_mma, sA, sB, sC, sA_load_op, sB_load_op + ); + + detail::epilogue_predication( + thr_mma, alpha, tCrC, beta, sC, tCsC, sC_load_op, sC_store_op + ); +} + +} // end namespace cute diff --git a/include/cute/algorithm/copy.hpp b/include/cute/algorithm/copy.hpp new file mode 100644 index 0000000000..84ef49161d --- /dev/null +++ b/include/cute/algorithm/copy.hpp @@ -0,0 +1,545 @@ 
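Before moving on to `copy.hpp`, here is a rough usage sketch for the epilogue-included `cooperative_gemm` overload above. The tile sizes, thread layout, and FMA-based TiledMMA are illustrative choices, not taken from the header; a real kernel would pick an architecture-specific MMA atom and matching tiles.

```cpp
#include <cute/tensor.hpp>
#include <cute/atom/mma_atom.hpp>                // cute::make_tiled_mma, cute::UniversalFMA
#include <cute/algorithm/cooperative_gemm.hpp>   // cute::cooperative_gemm

// Hypothetical kernel: 128 threads compute sC = 1.0f * sA * sB + 0.0f * sC on
// shared-memory tiles sA:(M,K)=(64,16), sB:(N,K)=(64,16), sC:(M,N)=(64,64).
__global__ void block_gemm_tile()   // launch with exactly 128 threads per block
{
  using namespace cute;
  __shared__ alignas(16) float smemA[64 * 16];
  __shared__ alignas(16) float smemB[64 * 16];
  __shared__ alignas(16) float smemC[64 * 64];

  Tensor sA = make_tensor(make_smem_ptr(smemA), make_shape(Int<64>{}, Int<16>{}));  // (M,K)
  Tensor sB = make_tensor(make_smem_ptr(smemB), make_shape(Int<64>{}, Int<16>{}));  // (N,K)
  Tensor sC = make_tensor(make_smem_ptr(smemC), make_shape(Int<64>{}, Int<64>{}));  // (M,N)

  // ... tiles are assumed to have been staged into shared memory here ...

  // A plain FMA-based TiledMMA with a 16x8x1 arrangement of the 128 threads.
  auto tiled_mma = make_tiled_mma(UniversalFMA<float,float,float,float>{},
                                  Layout<Shape<_16,_8,_1>>{});

  // Epilogue-included overload: sC(m,n) = alpha * sum_k sA(m,k) * sB(n,k) + beta * sC(m,n).
  cooperative_gemm(threadIdx.x, tiled_mma, 1.0f, sA, sB, 0.0f, sC);
  __syncthreads();
}
```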
+/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::Tensor +#include // cute::TrivialPredTensor +#include // cute::Copy_Atom + +namespace cute +{ + +// +// copy_if -- Predicated Copy +// + +template +CUTE_HOST_DEVICE +void +copy_if(PrdTensor const& pred, + Tensor const& src, + Tensor & dst) +{ + using SrcType = typename SrcEngine::value_type; + using DstType = typename DstEngine::value_type; + + CUTE_UNROLL + for (int i = 0; i < size(dst); ++i) { + if (pred(i)) { + dst(i) = static_cast(static_cast(src(i))); + } + } +} + +// +// copy_if -- Predicated CopyAtom +// + +template +CUTE_HOST_DEVICE +void +copy_if(Copy_Atom const& copy_atom, + PredTensor const& pred, // (Rest...) + Tensor const& src, // (V,Rest...) + Tensor & dst) // (V,Rest...) 
+{ + static_assert(SrcLayout::rank == DstLayout::rank, "CopyAtom rank-mismatch."); + auto has_with_bool = cute::is_valid([](auto t)->void_t().with(true))>{}, copy_atom); + + if constexpr (SrcLayout::rank == 1) { // Dispatch the copy + if constexpr (has_with_bool) { + copy_atom.with(pred()).call(src, dst); + } else { + if (pred()) { copy_atom.call(src, dst); } + } + } else { // Loop over all but the first mode + constexpr int R = SrcLayout::rank; + Tensor src_v = group_modes<1,R>(src); + Tensor dst_v = group_modes<1,R>(dst); + CUTE_UNROLL + for (int i = 0; i < size<1>(dst_v); ++i) { + if constexpr (has_with_bool) { + copy_atom.with(pred(i)).call(src_v(_,i), dst_v(_,i)); + } else { + if (pred(i)) { copy_atom.call(src_v(_,i), dst_v(_,i)); } + } + } + } +} + +// +// copy_if -- AutoCopyAsync +// +template +CUTE_HOST_DEVICE +void +copy_if(AutoCopyAsync const& cpy, + PrdTensor const& pred, + Tensor const& src, + Tensor & dst) +{ + using SrcElemWithConst = remove_reference_t; + using SrcType = typename SrcEngine::value_type; + using DstType = typename DstEngine::value_type; + + auto copy_op = []() { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + if constexpr (is_gmem::value && is_smem::value && + sizeof(SrcType) == sizeof(DstType)) { + if constexpr (is_const_v && sizeof(SrcType) == 16) { + return SM80_CP_ASYNC_CACHEGLOBAL{}; + } else if constexpr (sizeof(SrcType) == 4 || sizeof(SrcType) == 8 || sizeof(SrcType) == 16) { + return SM80_CP_ASYNC_CACHEALWAYS{}; + } else { + return UniversalCopy{}; + } + } else { + return UniversalCopy{}; + } + + CUTE_GCC_UNREACHABLE; +#else + return UniversalCopy{}; +#endif + }(); + + CUTE_UNROLL + for (int i = 0; i < size(dst); ++i) { + if (pred(i)) { + copy_op.copy(src(i), dst(i)); + } + } +} + +// +// copy -- AutoCopyAsync +// + +template +CUTE_HOST_DEVICE +void +copy(AutoCopyAsync const& cpy, + Tensor const& src, // (V,Rest...) + Tensor & dst) // (V,Rest...) +{ + copy_if(cpy, TrivialPredTensor{}, src, dst); +} + +// +// copy -- CopyAtom +// + +template +CUTE_HOST_DEVICE +void +copy(Copy_Atom const& copy_atom, + Tensor const& src, // (V,Rest...) + Tensor & dst) // (V,Rest...) 
+{ + static_assert(SrcLayout::rank == DstLayout::rank, "CopyAtom rank-mismatch."); + + if constexpr (SrcLayout::rank == 1) { // Dispatch the copy + copy_atom.call(src, dst); + } else { // Loop over all but the first mode + constexpr int R = SrcLayout::rank; + Tensor src_v = group_modes<1,R>(src); + Tensor dst_v = group_modes<1,R>(dst); + + if constexpr (is_static::value && is_static::value) { + CUTE_STATIC_ASSERT_V(size<1>(src_v) == size<1>(dst_v)); + + // AutoFilter on the Rest-mode + auto dst_null = nullspace(layout<1>(dst_v)); + + Tensor dst_n = zipped_divide(dst_v, make_tile(shape<0>(dst_v), dst_null)); // ((V, NLL), (_1, Rest)) + Tensor src_n = zipped_divide(src_v, make_tile(shape<0>(src_v), dst_null)); // ((V, NLL), (_1, Rest)) + + CUTE_STATIC_ASSERT_V(size<1>(src_n) == size<1>(dst_n)); + CUTE_STATIC_ASSERT_V((cosize<0,1>(dst_n.layout()) == Int<1>{}), "Nullspace definition error"); + CUTE_STATIC_ASSERT_V((cosize<0,1>(src_n.layout()) == Int<1>{}), "Error: Ambiguous scatter detected in copy"); + CUTE_STATIC_ASSERT_V((size<1,0>(dst_n) == Int<1>{})); + CUTE_STATIC_ASSERT_V((size<1,0>(src_n) == Int<1>{})); + + Tensor dst_c = dst_n(make_coord(_,Int<0>{}),make_coord(Int<0>{},_)); // (V, Rest) + Tensor src_c = src_n(make_coord(_,Int<0>{}),make_coord(Int<0>{},_)); // (V, Rest) + + CUTE_STATIC_ASSERT_V(size<1>(src_c) == size<1>(dst_c)); + CUTE_STATIC_ASSERT_V(shape<0>(dst_c) == shape<0>(dst)); + CUTE_STATIC_ASSERT_V(shape<0>(src_c) == shape<0>(src)); + + CUTE_UNROLL + for (int i = 0; i < size<1>(dst_c); ++i) { + copy_atom.call(src_c(_,i), dst_c(_,i)); + } + } else { + CUTE_UNROLL + for (int i = 0; i < size<1>(dst_v); ++i) { + copy_atom.call(src_v(_,i), dst_v(_,i)); + } + } + } +} + +//////////////////////////////////////////////////////// +// Special Auto-Vectorizing, Auto-Filtering Overloads // +//////////////////////////////////////////////////////// + +// Specialization for AutoVectorizingCopyAssumedAlignment +template +CUTE_HOST_DEVICE +void +copy(AutoVectorizingCopyWithAssumedAlignment const&, + Tensor const& src, + Tensor & dst) +{ + constexpr int common_elem = CUTE_STATIC_V(max_common_vector(src, dst)); + constexpr int align_bits = CUTE_STATIC_V(gcd(max_alignment(src), max_alignment(dst), Int{})); + static_assert(is_integral{} * sizeof_bits_v)>::value, "Error: Attempting a subbit copy!"); + constexpr int vec_bits = gcd(common_elem * sizeof_bits_v, align_bits); + + if constexpr (common_elem > 1 && ((vec_bits % 8) == 0)) { + // If more than one element vectorizes to 8bits or more, then recast and copy + using VecType = uint_bit_t; + // Preserve volatility + using SrcVecType = conditional_t, VecType const volatile, VecType const>; + using DstVecType = conditional_t, VecType volatile, VecType >; + + // Recast + Tensor src_v = recast(src); + Tensor dst_v = recast(dst); + +#if 0 + if (thread0()) { + print("copy -- found max_common_vector of %d elems and vectorization to %d bits\n", common_elem, vec_bits); + print(" "); print(src); print(" => "); print(src_v); print("\n"); + print(" "); print(dst); print(" => "); print(dst_v); print("\n"); + } +#endif + + return copy_if(TrivialPredTensor{}, src_v, dst_v); + } else { + return copy_if(TrivialPredTensor{}, src, dst); + } +} + +template +struct AutoFilter { + Base const& base; + CUTE_HOST_DEVICE AutoFilter(Base const& b) : base(b) {} +}; + +// Specialization for AutoFilter +template +CUTE_HOST_DEVICE +void +copy(AutoFilter const& copy_op, + Tensor const& src, + Tensor & dst) +{ + if constexpr (is_constant::value) { + auto dst_null = 
nullspace(dst.layout()); + + Tensor dst_n = zipped_divide(dst, dst_null); + Tensor src_n = zipped_divide(src, dst_null); + + CUTE_STATIC_ASSERT_V(cosize<0>(dst_n.layout()) == Int<1>{}, "Nullspace definition error"); + CUTE_STATIC_ASSERT_V(cosize<0>(src_n.layout()) == Int<1>{}, "Error: Ambiguous scatter detected in copy"); + + copy(copy_op.base, src_n(Int<0>{},_), dst_n(Int<0>{},_)); + } else { + copy(copy_op.base, src, dst); + } +} + +// Auto-vectorizing copy for static layouts +template +CUTE_HOST_DEVICE +void +copy(Tensor const& src, + Tensor & dst) +{ + if constexpr (is_static::value && is_static::value) { + // Assume Tensors with static layouts (e.g. registers) have pointers that are 128b aligned + return copy(AutoFilter(AutoVectorizingCopyWithAssumedAlignment<128>{}), src, dst); + } else + if constexpr (is_static::value && is_static::value) { + // Tensors with static shapes can be filtered, but do not assume that dynamic layouts are aligned. + return copy(AutoFilter(AutoVectorizingCopyWithAssumedAlignment<8>{}), src, dst); + } else { + // Do not assume that dynamic layouts are aligned. + return copy(AutoVectorizingCopyWithAssumedAlignment<8>{}, src, dst); + } +} + +// Auto-vectorizing copy with assumed alignment up to 128bit. +template +CUTE_HOST_DEVICE +void +copy_aligned(Tensor const& src, + Tensor & dst) +{ + if constexpr (is_static::value && is_static::value) { + // Tensors with static shapes can be filtered + return copy(AutoFilter(AutoVectorizingCopyWithAssumedAlignment<128>{}), src, dst); + } else { + return copy(AutoVectorizingCopyWithAssumedAlignment<128>{}, src, dst); + } +} + +// Specializaton for Atom AutoVectorizingCopyAssumedAlignment +template +CUTE_HOST_DEVICE +void +copy(Copy_Atom, Args...> const&, + Tensor const& src, + Tensor & dst) +{ + return copy(AutoVectorizingCopyWithAssumedAlignment{}, src, dst); +} + +#if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED) +template +CUTE_HOST_DEVICE +void +copy(Copy_Traits const& atom, // Copy_Traits may or may not have the memory barrier in it already + Tensor const& src, + Tensor & dst) +{ + using SrcType = typename SrcEngine::value_type; + using DstType = typename DstEngine::value_type; + static_assert(cute::is_same::value); + static_assert((is_gmem::value && is_smem::value) || + (is_smem::value && is_gmem::value), + "Bulk Copy only supports gmem -> smem or smem -> gmem movement."); + // G2S or S2G dispatch + using BULK_COPY_OP = conditional_t::value, + SM90_BULK_COPY_G2S, + SM90_BULK_COPY_S2G>; + + // Find the common subtensor of src and dst + auto tiler = max_common_layout(src, dst); + constexpr int vec_elem = decltype(size(tiler))::value; + constexpr int vec_bits = vec_elem * sizeof_bits_v; + static_assert(vec_bits >= 128, "Expected at least 128-bits for BLKCP"); + + // Construct a new concrete Atom of the vector size + using BulkAtom = Copy_Atom, CT_Args...>, SrcType>; + auto bulk_atom = apply(atom.opargs_, [](auto const&... args) { return BulkAtom{args...}; }); + +#if 0 + if (thread0()) { + print("copy blkcp -- found a max_common_layout of "); print(tiler); print("\n"); + print(" "); print(src); print("\n"); + print(" "); print(dst); print("\n"); + } +#endif + + return copy(bulk_atom, logical_divide(src, tiler), logical_divide(dst, tiler)); +} + +// Backwards-compat. Throw out any extra Copy_Atom args. 
+template +CUTE_HOST_DEVICE +void +copy(Copy_Atom, CA_Args...> const& atom, + Tensor const& src, + Tensor & dst) +{ + return copy(static_cast const&>(atom), src, dst); +} +#endif // #if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED) + +// +// Decay TiledCopy to CopyAtom +// + +template +CUTE_HOST_DEVICE +void +copy_if(TiledCopy const& tiled_copy, + PrdTensor const& pred, + Tensor const& src, + Tensor & dst) +{ + return copy_if(static_cast(tiled_copy), pred, src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy(TiledCopy const& tiled_copy, + Tensor const& src, + Tensor & dst) +{ + return copy(static_cast(tiled_copy), src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy_if(ThrCopy const& thr_copy, + PrdTensor const& pred, + Tensor const& src, + Tensor & dst) = delete; + +template +CUTE_HOST_DEVICE +void +copy(ThrCopy const& thr_copy, + Tensor const& src, + Tensor & dst) = delete; + +// +// Catch uncaught policies +// + +template +CUTE_HOST_DEVICE +void +copy_if(CopyPolicy const& cpy, + PredTensor const& prd, + Tensor const& src, + Tensor & dst) +{ + static_assert(dependent_false, "Unrecognized CopyPolicy."); +} + +template +CUTE_HOST_DEVICE +void +copy(CopyPolicy const& cpy, + Tensor const& src, + Tensor & dst) +{ + static_assert(dependent_false, "Unrecognized CopyPolicy."); +} + +// +// Accept mutable temporaries +// + +template +CUTE_HOST_DEVICE +void +copy_if(PrdTensor const& pred, + Tensor const& src, + Tensor && dst) +{ + return copy_if(pred, src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy_if(CopyPolicy const& copy_policy, + PrdTensor const& pred, + Tensor const& src, + Tensor && dst) +{ + return copy_if(copy_policy, pred, src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy(Tensor const& src, + Tensor && dst) +{ + return copy(src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy(CopyPolicy const& copy_policy, + Tensor const& src, + Tensor && dst) +{ + return copy(copy_policy, src, dst); +} + +template +CUTE_HOST_DEVICE +void +copy_aligned(Tensor const& src, + Tensor && dst) +{ + return copy_aligned(src, dst); +} + +} // end namespace cute diff --git a/include/cute/algorithm/fill.hpp b/include/cute/algorithm/fill.hpp new file mode 100644 index 0000000000..3f33a42ade --- /dev/null +++ b/include/cute/algorithm/fill.hpp @@ -0,0 +1,87 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +// +// Accept mutable temporaries +// +template +CUTE_HOST_DEVICE +void +fill(Tensor&& tensor, T const& value) +{ + return fill(tensor, value); +} + +namespace detail +{ + +// Prefer fill(tensor.data(), value), if possible +template +CUTE_HOST_DEVICE +auto +fill(Tensor& tensor, T const& value, prefer<1>) + -> decltype(fill(tensor.data(), value)) +{ + fill(tensor.data(), value); +} + +// Default implementation +template +CUTE_HOST_DEVICE +void +fill(Tensor& tensor, T const& value, prefer<0>) +{ + CUTE_UNROLL + for (int i = 0; i < size(tensor); ++i) { + tensor(i) = value; + } +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE +void +fill(Tensor& tensor, T const& value) +{ + return detail::fill(tensor, value, prefer<1>{}); +} + +} // end namespace cute diff --git a/include/cute/algorithm/functional.hpp b/include/cute/algorithm/functional.hpp new file mode 100644 index 0000000000..ef80d018d7 --- /dev/null +++ b/include/cute/algorithm/functional.hpp @@ -0,0 +1,290 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::max, cute::min +#include // cute::conj + +/** C++14 extensions */ + +namespace cute { + +/**************/ +/** Identity **/ +/**************/ + +struct identity { + template + CUTE_HOST_DEVICE constexpr + decltype(auto) operator()(T&& arg) const { + return static_cast(arg); + } +}; + +template +struct constant_fn { + template + CUTE_HOST_DEVICE constexpr + decltype(auto) operator()(T&&...) const { + return r_; + } + R r_; +}; + +/***********/ +/** Unary **/ +/***********/ + +#define CUTE_LEFT_UNARY_OP(NAME,OP) \ + struct NAME { \ + template \ + CUTE_HOST_DEVICE constexpr \ + decltype(auto) operator()(T&& arg) const { \ + return OP static_cast(arg); \ + } \ + } +#define CUTE_RIGHT_UNARY_OP(NAME,OP) \ + struct NAME { \ + template \ + CUTE_HOST_DEVICE constexpr \ + decltype(auto) operator()(T&& arg) const { \ + return static_cast(arg) OP ; \ + } \ + } +#define CUTE_NAMED_UNARY_OP(NAME,OP) \ + struct NAME { \ + template \ + CUTE_HOST_DEVICE constexpr \ + decltype(auto) operator()(T&& arg) const { \ + return OP (static_cast(arg)); \ + } \ + } + +CUTE_LEFT_UNARY_OP(unary_plus, +); +CUTE_LEFT_UNARY_OP(negate, -); +CUTE_LEFT_UNARY_OP(bit_not, ~); +CUTE_LEFT_UNARY_OP(logical_not, !); +CUTE_LEFT_UNARY_OP(dereference, *); +CUTE_LEFT_UNARY_OP(address_of, &); +CUTE_LEFT_UNARY_OP(pre_increment, ++); +CUTE_LEFT_UNARY_OP(pre_decrement, --); + +CUTE_RIGHT_UNARY_OP(post_increment, ++); +CUTE_RIGHT_UNARY_OP(post_decrement, --); + +CUTE_NAMED_UNARY_OP(abs_fn, abs); +CUTE_NAMED_UNARY_OP(conjugate, cute::conj); + +#undef CUTE_LEFT_UNARY_OP +#undef CUTE_RIGHT_UNARY_OP +#undef CUTE_NAMED_UNARY_OP + +template +struct shift_right_const { + static constexpr int Shift = Shift_; + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) operator()(T&& arg) const { + return static_cast(arg) >> Shift; + } +}; + +template +struct shift_left_const { + static constexpr int Shift = Shift_; + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) operator()(T&& arg) const { + return static_cast(arg) << Shift; + } +}; + +/************/ +/** Binary **/ +/************/ + +#define CUTE_BINARY_OP(NAME,OP) \ + struct NAME { \ + template \ + CUTE_HOST_DEVICE constexpr \ + decltype(auto) operator()(T&& lhs, U&& rhs) const { \ + return static_cast(lhs) OP static_cast(rhs); \ + } \ + } +#define CUTE_NAMED_BINARY_OP(NAME,OP) \ + struct NAME { \ + template \ + CUTE_HOST_DEVICE constexpr \ + decltype(auto) operator()(T&& lhs, U&& rhs) const { \ + return OP (static_cast(lhs), static_cast(rhs)); \ + } \ + } + + +CUTE_BINARY_OP(plus, +); +CUTE_BINARY_OP(minus, -); +CUTE_BINARY_OP(multiplies, *); +CUTE_BINARY_OP(divides, /); +CUTE_BINARY_OP(modulus, %); + +CUTE_BINARY_OP(plus_assign, +=); +CUTE_BINARY_OP(minus_assign, -=); +CUTE_BINARY_OP(multiplies_assign, *=); +CUTE_BINARY_OP(divides_assign, /=); +CUTE_BINARY_OP(modulus_assign, %=); + +CUTE_BINARY_OP(bit_and, &); +CUTE_BINARY_OP(bit_or, |); +CUTE_BINARY_OP(bit_xor, ^); +CUTE_BINARY_OP(left_shift, <<); +CUTE_BINARY_OP(right_shift, >>); + +CUTE_BINARY_OP(bit_and_assign, &=); +CUTE_BINARY_OP(bit_or_assign, |=); +CUTE_BINARY_OP(bit_xor_assign, ^=); +CUTE_BINARY_OP(left_shift_assign, <<=); +CUTE_BINARY_OP(right_shift_assign, >>=); + +CUTE_BINARY_OP(logical_and, &&); +CUTE_BINARY_OP(logical_or, ||); + +CUTE_BINARY_OP(equal_to, ==); +CUTE_BINARY_OP(not_equal_to, !=); +CUTE_BINARY_OP(greater, >); +CUTE_BINARY_OP(less, <); 
+CUTE_BINARY_OP(greater_equal, >=); +CUTE_BINARY_OP(less_equal, <=); + +CUTE_NAMED_BINARY_OP(max_fn, cute::max); +CUTE_NAMED_BINARY_OP(min_fn, cute::min); + +#undef CUTE_BINARY_OP +#undef CUTE_NAMED_BINARY_OP + +/**********/ +/** Fold **/ +/**********/ + +#define CUTE_FOLD_OP(NAME,OP) \ + struct NAME##_unary_rfold { \ + template \ + CUTE_HOST_DEVICE constexpr \ + auto operator()(T&&... t) const { \ + return (t OP ...); \ + } \ + }; \ + struct NAME##_unary_lfold { \ + template \ + CUTE_HOST_DEVICE constexpr \ + auto operator()(T&&... t) const { \ + return (... OP t); \ + } \ + }; \ + struct NAME##_binary_rfold { \ + template \ + CUTE_HOST_DEVICE constexpr \ + auto operator()(U&& u, T&&... t) const { \ + return (t OP ... OP u); \ + } \ + }; \ + struct NAME##_binary_lfold { \ + template \ + CUTE_HOST_DEVICE constexpr \ + auto operator()(U&& u, T&&... t) const { \ + return (u OP ... OP t); \ + } \ + } + +CUTE_FOLD_OP(plus, +); +CUTE_FOLD_OP(minus, -); +CUTE_FOLD_OP(multiplies, *); +CUTE_FOLD_OP(divides, /); +CUTE_FOLD_OP(modulus, %); + +CUTE_FOLD_OP(plus_assign, +=); +CUTE_FOLD_OP(minus_assign, -=); +CUTE_FOLD_OP(multiplies_assign, *=); +CUTE_FOLD_OP(divides_assign, /=); +CUTE_FOLD_OP(modulus_assign, %=); + +CUTE_FOLD_OP(bit_and, &); +CUTE_FOLD_OP(bit_or, |); +CUTE_FOLD_OP(bit_xor, ^); +CUTE_FOLD_OP(left_shift, <<); +CUTE_FOLD_OP(right_shift, >>); + +CUTE_FOLD_OP(bit_and_assign, &=); +CUTE_FOLD_OP(bit_or_assign, |=); +CUTE_FOLD_OP(bit_xor_assign, ^=); +CUTE_FOLD_OP(left_shift_assign, <<=); +CUTE_FOLD_OP(right_shift_assign, >>=); + +CUTE_FOLD_OP(logical_and, &&); +CUTE_FOLD_OP(logical_or, ||); + +CUTE_FOLD_OP(equal_to, ==); +CUTE_FOLD_OP(not_equal_to, !=); +CUTE_FOLD_OP(greater, >); +CUTE_FOLD_OP(less, <); +CUTE_FOLD_OP(greater_equal, >=); +CUTE_FOLD_OP(less_equal, <=); + +#undef CUTE_FOLD_OP + +/**********/ +/** Meta **/ +/**********/ + +template +struct bound_fn { + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator()(T&& arg) { + return fn_(arg_, static_cast(arg)); + } + + Fn fn_; + Arg arg_; +}; + +template +CUTE_HOST_DEVICE constexpr +auto +bind(Fn const& fn, Arg const& arg) { + return bound_fn{fn, arg}; +} + +} // end namespace cute diff --git a/include/cute/algorithm/gemm.hpp b/include/cute/algorithm/gemm.hpp new file mode 100644 index 0000000000..c4713838b6 --- /dev/null +++ b/include/cute/algorithm/gemm.hpp @@ -0,0 +1,500 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +#include + +#include + +/** The gemm algorithm takes four (or three) tensors and computes + * D = A * B + C + * It dispatches based on the number of modes each tensor has: + * + * 1. `(V) x (V) => (V)`. + * The element-wise product of vectors. Dispatches to FMA or MMA. + * 2. `(M) x (N) => (M,N)`. + * The outer product of vectors. Dispatches to [3] with new mode K=(1). + * 3. `(M,K) x (N,K) => (M,N)`. + * The product of matrices. Dispatches to [5] with MMA vector-mode V. + * 4. `(V,M) x (V,N) => (V,M,N)`. + * The batched outer product of vectors. Accounts for register reuse and dispatches to [1] for each (m,n). + * 5. `(V,M,K) x (V,N,K) => (V,M,N)`. + * The batched product of matrices. Dispatches to [4] for each (k). + */ + +namespace cute +{ + +// +// Three arguments to four +// + +template +CUTE_HOST_DEVICE +void +gemm(Tensor const& A, + Tensor const& B, + Tensor & C) +{ + return gemm(C, A, B, C); +} + +template +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor const& A, + Tensor const& B, + Tensor & C) +{ + return gemm(mma, C, A, B, C); +} + +// +// Accept mutable temporaries +// + +template +CUTE_HOST_DEVICE +void +gemm(Tensor const& A, + Tensor const& B, + Tensor && C) +{ + return gemm(C, A, B, C); +} + +template +CUTE_HOST_DEVICE +void +gemm(Tensor && D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + return gemm(D, A, B, C); +} + +template +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor const& A, + Tensor const& B, + Tensor && C) +{ + return gemm(mma, C, A, B, C); +} + +template +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor && D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + return gemm(mma, D, A, B, C); +} + +// +// Default MMA is UniversalFMA +// + +template +CUTE_HOST_DEVICE +void +gemm(Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + using MMA = MMA_Atom::value_type, + typename Tensor::value_type, + typename Tensor::value_type, + typename Tensor::value_type>>; + + return gemm(MMA{}, D, A, B, C); +} + +// +// Thread-Local Register-Memory GEMMs +// + +// Dispatch [1]: (V) x (V) => (V) +template ::value && + ALayout::rank == 1 && is_rmem::value && + BLayout::rank == 1 && is_rmem::value && + CLayout::rank == 1 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (V) Logical data + Tensor const& A, // (V) Logical data + Tensor const& B, // (V) Logical data + Tensor const& C) // (V) Logical data +{ + // No static assertions on (V), MMA checks compatibility + mma.call(D, A, B, C); +} + +// Dispatch [2]: (M) x (N) => (M,N) 
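+// The outer product of two vectors is lowered onto Dispatch [3] by appending a
+// unit K-mode to A and B. A minimal sketch of the equivalent scalar computation
+// (illustrative only; plain C-style arrays rather than the CuTe tensors used below):
+//
+//   for (int m = 0; m < M; ++m)
+//     for (int n = 0; n < N; ++n)
+//       D[m][n] = A[m] * B[n] + C[m][n];   // K-extent is 1
+//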
+template ::value && + ALayout::rank == 1 && is_rmem::value && + BLayout::rank == 1 && is_rmem::value && + CLayout::rank == 2 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (M,N) Logical data + Tensor const& A, // (M) Logical data + Tensor const& B, // (N) Logical data + Tensor const& C) // (M,N) Logical data +{ + CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C)); // AM == CM + CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C)); // BN == CN + CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D)); + gemm(mma, + D, // (M,N) + make_tensor(A.data(), append<2>(A.layout())), // (M,1) + make_tensor(B.data(), append<2>(B.layout())), // (N,1) + C); // (M,N) +} + +// Dispatch [3]: (M,K) x (N,K) => (M,N) +template ::value && + ALayout::rank == 2 && is_rmem::value && + BLayout::rank == 2 && is_rmem::value && + CLayout::rank == 2 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (M,N) Logical data + Tensor const& A, // (M,K) Logical data + Tensor const& B, // (N,K) Logical data + Tensor const& C) // (M,N) Logical data +{ + CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C)); // AM == CM + CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C)); // BN == CN + CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(B)); // AK == BK + CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D)); + + // Assert this is a 1-value MMA + CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutC_TV{}) == Int<1>{}); + CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutA_TV{}) == Int<1>{}); + CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutB_TV{}) == Int<1>{}); + + gemm(mma, + make_tensor(D.data(), prepend<3>(D.layout())), // (1,M,N) + make_tensor(A.data(), prepend<3>(A.layout())), // (1,M,K) + make_tensor(B.data(), prepend<3>(B.layout())), // (1,N,K) + make_tensor(C.data(), prepend<3>(C.layout()))); // (1,M,N) +} + +// Dispatch [4]: (V,M) x (V,N) => (V,M,N) +template ::value && + ALayout::rank == 2 && is_rmem::value && + BLayout::rank == 2 && is_rmem::value && + CLayout::rank == 3 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (V,M,N) Logical data + Tensor const& A, // (V,M) Logical data + Tensor const& B, // (V,N) Logical data + Tensor const& C) // (V,M,N) Logical data +{ + CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM + CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN + CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D)); + auto M = size<1>(A); + auto N = size<1>(B); + // REGISTER .reuse OPTIMIZATIONS + // 64-bit traversal specialization -- serpentine path + if constexpr (decltype(size<0>(A))::value * sizeof(typename TA::value_type) == 8 && + decltype(size<0>(B))::value * sizeof(typename TB::value_type) == 8) + { +#if 1 // NOTE: Row- vs Col- major could depend on the C-matrix order... (which we can test) + // Row-major serpentine iteration + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + int ns = (m & 1) ? N-1-n : n; // Serpentine coordinate + gemm(mma, D(_,m,ns), A(_,m), B(_,ns), C(_,m,ns)); + } + } +#else + // Col-major serpentine iteration + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + int ms = (n & 1) ? 
M-1-m : m; // Serpentine coordinate + gemm(mma, D(_,ms,n), A(_,ms), B(_,n), C(_,ms,n)); + } + } +#endif + } else + // 32-bit traversal specialization -- kinked serpentine path + if constexpr (decltype(size<0>(A))::value * sizeof(typename TA::value_type) == 4 && + decltype(size<0>(B))::value * sizeof(typename TB::value_type) == 4) + { +#if 1 // NOTE: Row- vs Col- major could depend on the C-matrix order... (which we can test) + // Row-major kinked serpentine iteration + CUTE_UNROLL + for (int m = 0; m < M; m += 2) { + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + int ns = (m & 2) ? N-1-n : n; + gemm(mma, D(_,m+0,ns), A(_,m+0), B(_,ns), C(_,m+0,ns)); + + if (m+1 < M) { + gemm(mma, D(_,m+1,ns), A(_,m+1), B(_,ns), C(_,m+1,ns)); + } + } + } +#else + // Col-major kinked serpentine iteration + CUTE_UNROLL + for (int n = 0; n < N; n += 2) { + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + // Kinked serpentine traversal for maximum register reuse + int ms = (n & 2) ? M-1-m : m; + gemm(mma, D(_,ms,n+0), A(_,ms), B(_,n+0), C(_,ms,n+0)); + + if (n+1 < N) { + gemm(mma, D(_,ms,n+1), A(_,ms), B(_,n+1), C(_,ms,n+1)); + } + } + } +#endif + } else + // 64-bit + 32-bit traversal order -- keep A (64-bit) in the outer loop and serpentine B + if constexpr (decltype(size<0>(A))::value * sizeof(typename TA::value_type) == 8 && + decltype(size<0>(B))::value * sizeof(typename TB::value_type) == 4) { + // Row-major serpentine iteration + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + int ns = (m & 1) ? N-1-n : n; // Serpentine coordinate + gemm(mma, D(_,m,ns), A(_,m), B(_,ns), C(_,m,ns)); + } + } + } else + // 32-bit + 64-bit traversal order -- keep B (64-bit) in the outer loop and serpentine A + if constexpr (decltype(size<0>(A))::value * sizeof(typename TA::value_type) == 4 && + decltype(size<0>(B))::value * sizeof(typename TB::value_type) == 8) { + // Col-major serpentine iteration + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + int ms = (n & 1) ? M-1-m : m; // Serpentine coordinate + gemm(mma, D(_,ms,n), A(_,ms), B(_,n), C(_,ms,n)); + } + } + } else + // Fallback to serpentine loop + { + // Col-major serpentine iteration + CUTE_UNROLL + for (int n = 0; n < N; ++n) { + CUTE_UNROLL + for (int m = 0; m < M; ++m) { + int ms = (n & 1) ? 
M-1-m : m; // Serpentine coordinate + gemm(mma, D(_,ms,n), A(_,ms), B(_,n), C(_,ms,n)); + } + } + } +} + +// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N) +template ::value && + ALayout::rank == 3 && is_rmem::value && + BLayout::rank == 3 && is_rmem::value && + CLayout::rank == 3 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (V,M,N) Logical data + Tensor const& A, // (V,M,K) Logical data + Tensor const& B, // (V,N,K) Logical data + Tensor const& C) // (V,M,N) Logical data +{ + CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM + CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN + CUTE_STATIC_ASSERT_V(size<2>(A) == size<2>(B)); // AK == BK + CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D)); + auto K = size<2>(A); + + CUTE_UNROLL + for (int k = 0; k < K; ++k) { + gemm(mma, D, A(_,_,k), B(_,_,k), C); + } +} + +// +// Thread-Local Shared-Memory GEMMs +// + +// Dispatch [1]: (V) x (V) => (V) +// Dispatch [2]: (M) x (N) => (M,N) +// Dispatch [3]: (M,K) x (N,K) => (M,N) +// Dispatch [4]: (V,M) x (V,N) => (V,M,N) +// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N) +// Dispatch [3]: (M,K) x (N,K) => (M,N) +template ::value && + ALayout::rank == 2 && is_smem::value && + BLayout::rank == 2 && is_smem::value && + CLayout::rank == 2 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (M,N) Logical data + Tensor const& A, // (M,K) Logical data + Tensor const& B, // (N,K) Logical data + Tensor const& C) // (M,N) Logical data +{ + CUTE_STATIC_ASSERT_V(size<0>(A) == size<0>(C)); // AM == CM + CUTE_STATIC_ASSERT_V(size<0>(B) == size<1>(C)); // BN == CN + CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(B)); // AK == BK + CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D)); + + // Assert this is a 1-value MMA + CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutC_TV{}) == Int<1>{}); + CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutA_TV{}) == Int<1>{}); + CUTE_STATIC_ASSERT_V(size<1>(typename MMA_Atom::LayoutB_TV{}) == Int<1>{}); + + gemm(mma, + make_tensor(D.data(), prepend<3>(D.layout())), // (1,M,N) + make_tensor(A.data(), prepend<3>(A.layout())), // (1,M,K) + make_tensor(B.data(), prepend<3>(B.layout())), // (1,N,K) + make_tensor(C.data(), prepend<3>(C.layout()))); // (1,M,N) +} + +// Dispatch [5]: (V,M,K) x (V,N,K) => (V,M,N) +template ::value && + ALayout::rank == 3 && is_smem::value && + BLayout::rank == 3 && is_smem::value && + CLayout::rank == 3 && is_rmem::value)> +CUTE_HOST_DEVICE +void +gemm(MMA_Atom const& mma, + Tensor & D, // (V,M,N) Logical data + Tensor const& A, // (V,M,K) Logical data + Tensor const& B, // (V,N,K) Logical data + Tensor const& C) // (V,M,N) Logical data +{ + CUTE_STATIC_ASSERT_V(size<1>(A) == size<1>(C)); // AM == CM + CUTE_STATIC_ASSERT_V(size<1>(B) == size<2>(C)); // BN == CN + CUTE_STATIC_ASSERT_V(size<2>(A) == size<2>(B)); // AK == BK + CUTE_STATIC_ASSERT_V(size<0>(C) == size<0>(D) && size<1>(C) == size<1>(D) && size<2>(C) == size<2>(D)); + + auto rA = MMA_Atom::make_fragment_A(A); + auto rB = MMA_Atom::make_fragment_B(B); + + auto K = size<2>(A); + + CUTE_UNROLL + for (int k = 0; k < K; ++k) + { + copy(A(_,_,k), rA(_,_,k)); + copy(B(_,_,k), rB(_,_,k)); + // Thread-level register gemm for k + gemm(mma, D, rA(_,_,k), rB(_,_,k), C); + } +} + +} // end namespace cute diff --git a/include/cute/algorithm/prefer.hpp b/include/cute/algorithm/prefer.hpp new file mode 100644 index 0000000000..a69e504298 --- 
/dev/null +++ b/include/cute/algorithm/prefer.hpp @@ -0,0 +1,46 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +namespace cute +{ + +// Infinite types that inherit from each other +template +struct prefer : prefer {}; + +template <> +struct prefer<0> {}; + +// Can be used to preferencially overload implementations +// Higher N in prefer have higher priority. + +} // end namespace cute diff --git a/include/cute/algorithm/prefetch.hpp b/include/cute/algorithm/prefetch.hpp new file mode 100644 index 0000000000..c39f63acdd --- /dev/null +++ b/include/cute/algorithm/prefetch.hpp @@ -0,0 +1,145 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::Tensor +#include // cute::Copy_Atom + +namespace cute +{ + +// +// Prefetch global tensors into L2 +// + +template +CUTE_HOST_DEVICE +void +cooperative_prefetch(uint32_t const& tid, + Tensor const& src) +{ + static_assert(is_gmem::value, "Expected global tensor for prefetch"); + + constexpr int V = decltype(max_common_vector(src, src))::value; + + if constexpr (V > 1) { + // L2 sector is 32B, default fetch granularity is 64B + using VecType = conditional_t<(V * sizeof_bits_v) < (FetchBytes * 8), + ArrayEngine, + uint8_t[FetchBytes] >; + + Tensor src_v = recast(src); + CUTE_UNROLL + for (int i = tid; i < size(src_v); i += NumThreads) { + prefetch(raw_pointer_cast(&src_v(i))); + } + } else { + CUTE_UNROLL + for (int i = tid; i < size(src); i += NumThreads) { + prefetch(raw_pointer_cast(&src(i))); + } + } +} + +template +CUTE_HOST_DEVICE +void +prefetch(Tensor const& src) +{ + return cooperative_prefetch<1>(0, src); +} + +// Prefetch with copy atom +namespace detail { + +template +constexpr bool has_prefetch = false; + +template +constexpr bool has_prefetch> = true; + +} // end namespace detail + +template +CUTE_HOST_DEVICE +void +prefetch(Copy_Atom, CA_Args...> const& atom, + Tensor const& src) +{ + if constexpr (detail::has_prefetch) { + using Prefetch_Traits = Copy_Traits; + using Prefetch_Atom = Copy_Atom; + Prefetch_Atom prefetch_atom{atom}; + auto& dst = const_cast&>(src); // dst is ignored for prefetch atoms + return copy(prefetch_atom, src, dst); + } else { + return prefetch(src); + } +} + +#if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED) +template +CUTE_HOST_DEVICE +void +prefetch(Copy_Traits const& atom, + Tensor const& src) +{ + using SrcType = typename SrcEngine::value_type; + static_assert(is_gmem::value, "Expected global tensor for L2 prefetch"); + + auto tiler = max_common_layout(src, src); + constexpr int vec_elem = decltype(size(tiler))::value; + constexpr int vec_bits = vec_elem * sizeof_bits_v; + static_assert(vec_bits >= 128, "Expected at least 128-bits for BLKCP"); + + // Construct a new concrete Atom of the vector size + auto bulk_atom = Copy_Atom>, SrcType>{}; + + return prefetch(bulk_atom, logical_divide(src, tiler)); +} + +// Backwards-compat. Throw out any extra Copy_Atom args. 
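+// (This overload strips the Copy_Atom wrapper via a cast to its Copy_Traits base
+// and forwards to the TMA prefetch overload above; the atom-only template
+// arguments carry no information that the prefetch path needs.)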
+template +CUTE_HOST_DEVICE +void +prefetch(Copy_Atom, CA_Args...> const& atom, + Tensor const& src) +{ + return prefetch(static_cast const&>(atom), src); +} +#endif // #if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED) + +} // end namespace cute diff --git a/include/cute/algorithm/tensor_algorithms.hpp b/include/cute/algorithm/tensor_algorithms.hpp new file mode 100644 index 0000000000..dbffc61335 --- /dev/null +++ b/include/cute/algorithm/tensor_algorithms.hpp @@ -0,0 +1,166 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/** Common algorithms on (hierarchical) tensors */ + +#pragma once + +#include +#include + +namespace cute +{ + +// +// for_each +// + +template +CUTE_HOST_DEVICE constexpr +void +for_each(Tensor const& tensor, UnaryOp&& op) +{ + CUTE_UNROLL + for (int i = 0; i < size(tensor); ++i) { + op(tensor(i)); + } +} + +template +CUTE_HOST_DEVICE constexpr +void +for_each(Tensor& tensor, UnaryOp&& op) +{ + CUTE_UNROLL + for (int i = 0; i < size(tensor); ++i) { + op(tensor(i)); + } +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE constexpr +void +for_each(Tensor&& tensor, UnaryOp&& op) +{ + return for_each(tensor, op); +} + +// +// transform +// + +// Similar to std::transform but does not return number of elements affected +template +CUTE_HOST_DEVICE constexpr +void +transform(Tensor& tensor, UnaryOp&& op) +{ + CUTE_UNROLL + for (int i = 0; i < size(tensor); ++i) { + tensor(i) = op(tensor(i)); + } +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE constexpr +void +transform(Tensor&& tensor, UnaryOp&& op) +{ + return transform(tensor, op); +} + +// Similar to std::transform transforms one tensors and assigns it to another +template +CUTE_HOST_DEVICE constexpr +void +transform(Tensor const& tensor_in, + Tensor & tensor_out, + UnaryOp&& op) +{ + CUTE_UNROLL + for (int i = 0; i < size(tensor_in); ++i) { + tensor_out(i) = op(tensor_in(i)); + } +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE constexpr +void +transform(Tensor const& tensor_in, + Tensor && tensor_out, + UnaryOp&& op) +{ + return transform(tensor_in, tensor_out, op); +} + +// Similar to std::transform with a binary operation +// Takes two tensors as input and one tensor as output. +// Applies the binary_op to tensor_in1 and tensor_in2 and +// assigns it to tensor_out +template +CUTE_HOST_DEVICE constexpr +void +transform(Tensor const& tensor_in1, + Tensor const& tensor_in2, + Tensor & tensor_out, + BinaryOp&& op) +{ + CUTE_UNROLL + for (int i = 0; i < size(tensor_in1); ++i) { + tensor_out(i) = op(tensor_in1(i), tensor_in2(i)); + } +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE constexpr +void +transform(Tensor const& tensor_in1, + Tensor const& tensor_in2, + Tensor && tensor_out, + BinaryOp&& op) +{ + return transform(tensor_in1, tensor_in2, tensor_out, op); +} + +} // end namespace cute diff --git a/include/cute/algorithm/tuple_algorithms.hpp b/include/cute/algorithm/tuple_algorithms.hpp new file mode 100644 index 0000000000..5a70f590b6 --- /dev/null +++ b/include/cute/algorithm/tuple_algorithms.hpp @@ -0,0 +1,1073 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include +#include +#include + +/// @file tuple_algorithms.hpp +/// @brief Common algorithms on (hierarchical) tuples +/// +/// Code guidelines and style preferences: +/// +/// For perfect forwarding, don't use std::forward, because it may not +/// be defined in device code when compiling with NVRTC. Instead, use +/// `static_cast(parameter_name)`. +/// +/// CuTe generally does not bother forwarding functions, as +/// reference-qualified member functions are rare in this code base. +/// +/// Throughout CUTLASS, cute::make_tuple always needs to be called +/// namespace-qualified, EVEN If inside the cute namespace and/or in +/// scope of a "using namespace cute" declaration. Otherwise, the +/// compiler may select std::make_tuple instead of cute::make_tuple, +/// due to argument-dependent lookup. + +namespace cute +{ + +// +// Apply (Unpack) +// (t, f) => f(t_0,t_1,...,t_n) +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +apply(T&& t, F&& f, seq) +{ + return f(get(static_cast(t))...); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +apply(T&& t, F&& f) +{ + return detail::apply(static_cast(t), f, tuple_seq{}); +} + +// +// Transform Apply +// (t, f, g) => g(f(t_0),f(t_1),...) 
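+// e.g. an illustrative example with runtime integers:
+//   transform_apply(cute::make_tuple(2, 3),
+//                   [](auto x) { return x * x; },        // f: square each element
+//                   [](auto... v) { return (v + ...); }) // g: sum the results
+//   => 13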
+// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +tapply(T&& t, F&& f, G&& g, seq) +{ + return g(f(get(static_cast(t)))...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tapply(T0&& t0, T1&& t1, F&& f, G&& g, seq) +{ + return g(f(get(static_cast(t0)), + get(static_cast(t1)))...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tapply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g, seq) +{ + return g(f(get(static_cast(t0)), + get(static_cast(t1)), + get(static_cast(t2)))...); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +transform_apply(T&& t, F&& f, G&& g) +{ + if constexpr (is_tuple>::value) { + return detail::tapply(static_cast(t), f, g, tuple_seq{}); + } else { + return g(f(static_cast(t))); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform_apply(T0&& t0, T1&& t1, F&& f, G&& g) +{ + if constexpr (is_tuple>::value) { + return detail::tapply(static_cast(t0), static_cast(t1), f, g, tuple_seq{}); + } else { + return g(f(static_cast(t0), static_cast(t1))); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform_apply(T0&& t0, T1&& t1, T2&& t2, F&& f, G&& g) +{ + if constexpr (is_tuple>::value) { + return detail::tapply(static_cast(t0), static_cast(t1), static_cast(t2), f, g, tuple_seq{}); + } else { + return g(f(static_cast(t0), static_cast(t1), static_cast(t2))); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// For Each +// (t, f) => f(t_0),f(t_1),...,f(t_n) +// + +template +CUTE_HOST_DEVICE constexpr +void +for_each(T&& t, F&& f) +{ + if constexpr (is_tuple>::value) { + return detail::apply(t, [&](auto&&... a) { (f(static_cast(a)), ...); }, tuple_seq{}); + } else { + return f(static_cast(t)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +for_each_leaf(T&& t, F&& f) +{ + if constexpr (is_tuple>::value) { + return detail::apply(static_cast(t), [&](auto&&... a){ return (for_each_leaf(static_cast(a), f), ...); }, tuple_seq{}); + } else { + return f(static_cast(t)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Transform +// (t, f) => (f(t_0),f(t_1),...,f(t_n)) +// + +template +CUTE_HOST_DEVICE constexpr +auto +transform(T const& t, F&& f) +{ + if constexpr (is_tuple::value) { + return detail::tapply(t, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq{}); + } else { + return f(t); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform(T0 const& t0, T1 const& t1, F&& f) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); + return detail::tapply(t0, t1, f, [](auto const&... a){ return cute::make_tuple(a...); }, tuple_seq{}); + } else { + return f(t0, t1); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform(T0 const& t0, T1 const& t1, T2 const& t2, F&& f) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); + static_assert(tuple_size::value == tuple_size::value, "Mismatched tuple_size"); + return detail::tapply(t0, t1, t2, f, [](auto const&... 
a){ return cute::make_tuple(a...); }, tuple_seq{}); + } else { + return f(t0, t1, t2); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform_leaf(T const& t, F&& f) +{ + if constexpr (is_tuple::value) { + return transform(t, [&](auto const& a) { return transform_leaf(a, f); }); + } else { + return f(t); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform_leaf(T0 const& t0, T1 const& t1, F&& f) +{ + if constexpr (is_tuple::value) { + return transform(t0, t1, [&](auto const& a, auto const& b) { return transform_leaf(a, b, f); }); + } else { + return f(t0, t1); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// find and find_if +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +find_if(T const& t, F&& f, seq) +{ + if constexpr (decltype(f(get(t)))::value) { + return cute::C{}; + } else + if constexpr (sizeof...(Is) == 0) { + return cute::C{}; + } else { + return find_if(t, f, seq{}); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +find_if(T const& t, F&& f) +{ + if constexpr (is_tuple::value) { + return detail::find_if(t, f, tuple_seq{}); + } else { + return cute::C{}; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +find(T const& t, X const& x) +{ + return find_if(t, [&](auto const& v) { return v == x; }); // This should always return a static true/false +} + +template +CUTE_HOST_DEVICE constexpr +auto +any_of(T const& t, F&& f) +{ + if constexpr (is_tuple::value) { + return detail::apply(cute::transform(t, f), [&] (auto const&... a) { return (false_type{} || ... || a); }, tuple_seq{}); + } else { + return f(t); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +all_of(T const& t, F&& f) +{ + if constexpr (is_tuple::value) { + return detail::apply(cute::transform(t, f), [&] (auto const&... a) { return (true_type{} && ... && a); }, tuple_seq{}); + } else { + return f(t); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +none_of(T const& t, F&& f) +{ + return not any_of(t, f); +} + +// +// Filter +// (t, f) => +// + +template +CUTE_HOST_DEVICE constexpr +auto +filter_tuple(T const& t, F&& f) +{ + return transform_apply(t, f, [](auto const&... a) { return cute::tuple_cat(a...); }); +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter_tuple(T0 const& t0, T1 const& t1, F&& f) +{ + return transform_apply(t0, t1, f, [](auto const&... a) { return cute::tuple_cat(a...); }); +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter_tuple(T0 const& t0, T1 const& t1, T2 const& t2, F&& f) +{ + return transform_apply(t0, t1, t2, f, [](auto const&... a) { return cute::tuple_cat(a...); }); +} + +// +// Fold (Reduce, Accumulate) +// (t, v, f) => f(...f(f(v,t_0),t_1),...,t_n) +// + +namespace detail { + +template +struct FoldAdaptor { + template + CUTE_HOST_DEVICE constexpr auto operator|(X&& x) { + auto r = fn_(val_, static_cast(x)); + return FoldAdaptor{fn_, r}; + } + Fn fn_; + Val val_; +}; + +template +CUTE_HOST_DEVICE constexpr +auto +fold(T&& t, V const& v, F&& f, seq) +{ + return (FoldAdaptor{f,v} | ... 
| get(static_cast(t))).val_; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +fold(T&& t, V const& v, F&& f) +{ + if constexpr (is_tuple>::value) { + return detail::fold(static_cast(t), v, f, tuple_seq{}); + } else { + return f(v, static_cast(t)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +fold_first(T&& t, F&& f) +{ + if constexpr (is_tuple>::value) { + return detail::fold(static_cast(t), get<0>(t), f, make_range<1,tuple_size>::value>{}); + } else { + return t; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// front, back, take, select, unwrap +// + +// Get the first non-tuple element in a hierarchical tuple +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +front(T&& t) +{ + if constexpr (is_tuple>::value) { + return front(get<0>(static_cast(t))); + } else { + return static_cast(t); + } + + CUTE_GCC_UNREACHABLE; +} + +// Get the last non-tuple element in a hierarchical tuple +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +back(T&& t) +{ + if constexpr (is_tuple>::value) { + constexpr int N = tuple_size>::value; + + // MSVC needs a bit of extra help here deducing return types. + // We help it by peeling off the nonrecursive case a level "early." + if constexpr (! is_tuple(static_cast(t)))>>::value) { + return get(static_cast(t)); + } else { + return back(get(static_cast(t))); + } + } else { + return static_cast(t); + } + + CUTE_GCC_UNREACHABLE; +} + +// Takes the elements in the range [B,E) +template +CUTE_HOST_DEVICE constexpr +auto +take(T const& t) +{ + if constexpr (E == -1) { + if constexpr (is_tuple::value) { + return take::value>(t); + } else { + return take(t); + } + } else + if constexpr (B <= E) { + return detail::apply(t, [](auto const&... a) { return cute::make_tuple(a...); }, make_range{}); + } else { + static_assert(B <= E); + } + + CUTE_GCC_UNREACHABLE; +} + +// Select tuple elements with given indices. +template +CUTE_HOST_DEVICE constexpr +auto +select(T const& t) +{ + return cute::make_tuple(get(t)...); +} + +// Wrap non-tuples into rank-1 tuples or forward +template +CUTE_HOST_DEVICE constexpr +auto +wrap(T const& t) +{ + if constexpr (is_tuple::value) { + return t; + } else { + return cute::make_tuple(t); + } + + CUTE_GCC_UNREACHABLE; +} + +// Unwrap rank-1 tuples until we're left with a rank>1 tuple or a non-tuple +template +CUTE_HOST_DEVICE constexpr +auto +unwrap(T const& t) +{ + if constexpr (is_tuple::value) { + if constexpr (tuple_size::value == 1) { + return unwrap(get<0>(t)); + } else { + return t; + } + } else { + return t; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Flatten and Unflatten +// + +template +struct is_flat : true_type {}; + +template +struct is_flat> : bool_constant<(true && ... && (not is_tuple::value))> {}; + +// Flatten a hierarchical tuple to a tuple of depth one +// and wrap non-tuples into a rank-1 tuple. +template +CUTE_HOST_DEVICE constexpr +auto +flatten_to_tuple(T const& t) +{ + if constexpr (is_tuple::value) { + if constexpr (is_flat::value) { // Shortcut for perf + return t; + } else { + return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); }); + } + } else { + return cute::make_tuple(t); + } + + CUTE_GCC_UNREACHABLE; +} + +// Flatten a hierarchical tuple to a tuple of depth one +// and leave non-tuple untouched. 
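+//
+// e.g. (illustrative)
+//   flatten(cute::make_tuple(cute::make_tuple(Int<1>{}, Int<2>{}), Int<3>{}))
+//     => (Int<1>{}, Int<2>{}, Int<3>{})
+//   flatten(Int<1>{})          =>  Int<1>{}   // non-tuple passes through
+//   flatten_to_tuple(Int<1>{}) => (Int<1>{})  // ...whereas this wraps it in a rank-1 tuple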
+template +CUTE_HOST_DEVICE constexpr +auto +flatten(T const& t) +{ + if constexpr (is_tuple::value) { + if constexpr (is_flat::value) { // Shortcut for perf + return t; + } else { + return filter_tuple(t, [](auto const& a) { return flatten_to_tuple(a); }); + } + } else { + return t; + } + + CUTE_GCC_UNREACHABLE; +} + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +unflatten_impl(FlatTuple const& flat_tuple, TargetProfile const& target_profile) +{ + if constexpr (is_tuple::value) { + return fold(target_profile, cute::make_tuple(cute::make_tuple(), flat_tuple), [](auto const& v, auto const& t) { + auto [result, remaining_tuple] = v; + auto [sub_result, sub_tuple] = unflatten_impl(remaining_tuple, t); + return cute::make_tuple(append(result, sub_result), sub_tuple); + }); + } else { + return cute::make_tuple(get<0>(flat_tuple), take<1, decltype(rank(flat_tuple))::value>(flat_tuple)); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +// Unflatten a flat tuple into a hierarchical tuple +// @pre flatten(@a flat_tuple) == @a flat_tuple +// @pre rank(flatten(@a target_profile)) == rank(@a flat_tuple) +// @post congruent(@a result, @a target_profile) +// @post flatten(@a result) == @a flat_tuple +template +CUTE_HOST_DEVICE constexpr +auto +unflatten(FlatTuple const& flat_tuple, TargetProfile const& target_profile) +{ + auto [unflatten_tuple, flat_remainder] = detail::unflatten_impl(flat_tuple, target_profile); + CUTE_STATIC_ASSERT_V(rank(flat_remainder) == Int<0>{}); + return unflatten_tuple; +} + +// +// insert and remove and replace +// + +namespace detail { + +// Shortcut around cute::tuple_cat for common insert/remove/repeat cases +template +CUTE_HOST_DEVICE constexpr +auto +construct(T const& t, X const& x, seq, seq, seq) +{ + return cute::make_tuple(get(t)..., (void(J),x)..., get(t)...); +} + +} // end namespace detail + +// Insert x into the Nth position of the tuple +template +CUTE_HOST_DEVICE constexpr +auto +insert(T const& t, X const& x) +{ + return detail::construct(t, x, make_seq{}, seq<0>{}, make_range::value>{}); +} + +// Remove the Nth element of the tuple +template +CUTE_HOST_DEVICE constexpr +auto +remove(T const& t) +{ + return detail::construct(t, 0, make_seq{}, seq<>{}, make_range::value>{}); +} + +// Replace the Nth element of the tuple with x +template +CUTE_HOST_DEVICE constexpr +auto +replace(T const& t, X const& x) +{ + if constexpr (is_tuple::value) { + return detail::construct(t, x, make_seq{}, seq<0>{}, make_range::value>{}); + } else { + static_assert(N == 0); + return x; + } + + CUTE_GCC_UNREACHABLE; +} + +// Replace the first element of the tuple with x +template +CUTE_HOST_DEVICE constexpr +auto +replace_front(T const& t, X const& x) +{ + if constexpr (is_tuple::value) { + return detail::construct(t, x, seq<>{}, seq<0>{}, make_range<1,tuple_size::value>{}); + } else { + return x; + } + + CUTE_GCC_UNREACHABLE; +} + +// Replace the last element of the tuple with x +template +CUTE_HOST_DEVICE constexpr +auto +replace_back(T const& t, X const& x) +{ + if constexpr (is_tuple::value) { + return detail::construct(t, x, make_seq::value-1>{}, seq<0>{}, seq<>{}); + } else { + return x; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Make a tuple of Xs of tuple_size N +// + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_repeat(X const& x) +{ + return detail::construct(0, x, seq<>{}, make_seq{}, seq<>{}); +} + +// +// Make repeated Xs of rank N +// + +template +CUTE_HOST_DEVICE constexpr +auto +repeat(X const& x) +{ + if constexpr (N == 1) { + 
return x; + } else { + return detail::construct(0, x, seq<>{}, make_seq{}, seq<>{}); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Make a tuple of Xs the same profile as tuple T +// + +template +CUTE_HOST_DEVICE constexpr +auto +repeat_like(T const& t, X const& x) +{ + if constexpr (is_tuple::value) { + return transform(t, [&](auto const& a) { return repeat_like(a,x); }); + } else { + return x; + } + + CUTE_GCC_UNREACHABLE; +} + +// Group the elements [B,E) of a T into a single element +// e.g. group<2,4>(T<_1,_2,_3,_4,_5,_6>{}) +// => T<_1,_2,T<_3,_4>,_5,_6>{} +template +CUTE_HOST_DEVICE constexpr +auto +group(T const& t) +{ + if constexpr (not is_tuple::value) { + if constexpr (E == -1) { + return group(t); + } else { + return detail::construct(t, take(t), make_seq{}, make_seq<(B < E)>{}, make_range{}); + } + } else + if constexpr (E == -1) { + return group::value>(t); + } else + if constexpr (B <= E) { + return detail::construct(t, take(t), make_seq{}, make_seq<(B < E)>{}, make_range::value>{}); + } else { + static_assert(B <= E); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Extend a T to rank N by appending/prepending an element +// + +template +CUTE_HOST_DEVICE constexpr +auto +append(T const& a, X const& x) +{ + if constexpr (is_tuple::value) { + if constexpr (N == tuple_size::value) { + return a; + } else { + static_assert(N > tuple_size::value); + return detail::construct(a, x, make_seq::value>{}, make_seq::value>{}, seq<>{}); + } + } else { + if constexpr (N == 1) { + return a; + } else { + return detail::construct(cute::make_tuple(a), x, seq<0>{}, make_seq{}, seq<>{}); + } + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +append(T const& a, X const& x) +{ + if constexpr (is_tuple::value) { + return detail::construct(a, x, make_seq::value>{}, seq<0>{}, seq<>{}); + } else { + return cute::make_tuple(a, x); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +prepend(T const& a, X const& x) +{ + if constexpr (is_tuple::value) { + if constexpr (N == tuple_size::value) { + return a; + } else { + static_assert(N > tuple_size::value); + return detail::construct(a, x, seq<>{}, make_seq::value>{}, make_seq::value>{}); + } + } else { + if constexpr (N == 1) { + return a; + } else { + static_assert(N > 1); + return detail::construct(cute::make_tuple(a), x, seq<>{}, make_seq{}, seq<0>{}); + } + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +prepend(T const& a, X const& x) +{ + if constexpr (is_tuple::value) { + return detail::construct(a, x, seq<>{}, seq<0>{}, make_seq::value>{}); + } else { + return cute::make_tuple(x, a); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Inclusive scan (prefix sum) +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +iscan(T const& t, V const& v, F&& f, seq) +{ + // Apply the function to v and the element at I + auto v_next = f(v, get(t)); + // Replace I with v_next + auto t_next = replace(t, v_next); + +#if 0 + std::cout << "ISCAN i" << I << std::endl; + std::cout << " t " << t << std::endl; + std::cout << " i " << v << std::endl; + std::cout << " f(i,t) " << v_next << std::endl; + std::cout << " t_n " << t_next << std::endl; +#endif + + if constexpr (sizeof...(Is) == 0) { + return t_next; + } else { + return iscan(t_next, v_next, f, seq{}); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +iscan(T const& t, V const& v, F&& f) +{ + return detail::iscan(t, v, f, tuple_seq{}); +} + +// +// Exclusive scan 
(prefix sum) +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +escan(T const& t, V const& v, F&& f, seq) +{ + if constexpr (sizeof...(Is) == 0) { + // Replace I with v + return replace(t, v); + } else { + // Apply the function to v and the element at I + auto v_next = f(v, get(t)); + // Replace I with v + auto t_next = replace(t, v); + +#if 0 + std::cout << "ESCAN i" << I << std::endl; + std::cout << " t " << t << std::endl; + std::cout << " i " << v << std::endl; + std::cout << " f(i,t) " << v_next << std::endl; + std::cout << " t_n " << t_next << std::endl; +#endif + + // Recurse + return escan(t_next, v_next, f, seq{}); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +escan(T const& t, V const& v, F&& f) +{ + return detail::escan(t, v, f, tuple_seq{}); +} + +// +// Zip (Transpose) +// + +// Take ((a,b,c,...),(x,y,z,...),...) rank-R0 x rank-R1 input +// to produce ((a,x,...),(b,y,...),(c,z,...),...) rank-R1 x rank-R0 output + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +zip_(Ts const&... ts) +{ + return cute::make_tuple(get(ts)...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +zip(T const& t, seq, seq) +{ + static_assert(conjunction>::value == tuple_size>::value>...>::value, "Mismatched Ranks"); + return cute::make_tuple(zip_(get(t)...)...); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +zip(T const& t) +{ + if constexpr (is_tuple::value) { + if constexpr (is_tuple>::value) { + return detail::zip(t, tuple_seq{}, tuple_seq>{}); + } else { + return cute::make_tuple(t); + } + } else { + return t; + } + + CUTE_GCC_UNREACHABLE; +} + +// Convenient to pass them in separately +template +CUTE_HOST_DEVICE constexpr +auto +zip(T0 const& t0, T1 const& t1, Ts const&... ts) +{ + return zip(cute::make_tuple(t0, t1, ts...)); +} + +// +// zip2_by -- A guided zip for rank-2 tuples +// Take a tuple like ((A,a),((B,b),(C,c)),d) +// and produce a tuple ((A,(B,C)),(a,(b,c),d)) +// where the rank-2 modes are selected by the terminals of the guide (X,(X,X)) +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +zip2_by(T const& t, TG const& guide, seq, seq) +{ + // zip2_by produces the modes like ((A,a),(B,b),...) + auto split = cute::make_tuple(zip2_by(get(t), get(guide))...); + + // Rearrange and append missing modes from t to make ((A,B,...),(a,b,...,x,y)) + return cute::make_tuple(cute::make_tuple(get<0>(get(split))...), + cute::make_tuple(get<1>(get(split))..., get(t)...)); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +zip2_by(T const& t, TG const& guide) +{ + if constexpr (is_tuple::value) { + constexpr int TR = tuple_size::value; + constexpr int GR = tuple_size::value; + static_assert(TR >= GR, "Mismatched ranks"); + return detail::zip2_by(t, guide, + make_range< 0, GR>{}, + make_range{}); + } else { + static_assert(tuple_size::value == 2, "Mismatched ranks"); + return t; + } + + CUTE_GCC_UNREACHABLE; +} + +/// @return A tuple of the elements of @c t in reverse order. +template +CUTE_HOST_DEVICE constexpr +auto +reverse(T const& t) +{ + if constexpr (is_tuple::value) { + return detail::apply(t, [](auto const&... 
a){ return cute::make_tuple(a...); }, tuple_rseq{}); + } else { + return t; + } +} + +} // end namespace cute diff --git a/include/cute/arch/cluster_sm90.hpp b/include/cute/arch/cluster_sm90.hpp new file mode 100644 index 0000000000..8fff51be8e --- /dev/null +++ b/include/cute/arch/cluster_sm90.hpp @@ -0,0 +1,245 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && \ + ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8)))) +# define CUTE_ARCH_CLUSTER_SM90_ENABLED +#endif + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && (__CUDACC_VER_MAJOR__ >= 12)) +# define CUTE_ARCH_ELECT_ONE_SM90_ENABLED +#endif + +namespace cute { + +CUTE_DEVICE void cluster_arrive_relaxed() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + asm volatile("barrier.cluster.arrive.relaxed.aligned;\n" : : ); +#else + CUTE_INVALID_CONTROL_PATH("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined"); +#endif +} + +CUTE_DEVICE void cluster_arrive() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + asm volatile("barrier.cluster.arrive.aligned;\n" : : ); +#else + CUTE_INVALID_CONTROL_PATH("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined"); +#endif +} + +CUTE_DEVICE void cluster_wait() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + asm volatile("barrier.cluster.wait.aligned;\n" : : ); +#else + CUTE_INVALID_CONTROL_PATH("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined"); +#endif +} + +CUTE_DEVICE void cluster_sync() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + cluster_arrive(); + cluster_wait(); +#else + CUTE_INVALID_CONTROL_PATH("CUTE_ARCH_CLUSTER_SM90_ENABLED is not defined"); +#endif +} + +// Returns the dim3 grid size in terms of number of clusters. +CUTE_DEVICE dim3 cluster_grid_dims() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t x, y, z; + asm volatile("mov.u32 %0, %%nclusterid.x;\n" : "=r"(x) : ); + asm volatile("mov.u32 %0, %%nclusterid.y;\n" : "=r"(y) : ); + asm volatile("mov.u32 %0, %%nclusterid.z;\n" : "=r"(z) : ); + return {x, y, z}; +#elif defined(__CUDA_ARCH__) + // MSVC requires protecting use of gridDim with __CUDA_ARCH__. + return gridDim; +#elif defined(_MSC_VER) + CUTE_INVALID_CONTROL_PATH("cluster_grid_dims() can only be called on device"); + return {0, 0, 0}; +#else + return {0, 0, 0}; +#endif +} + +// Returns the dim3 cluster rank in the grid. +CUTE_DEVICE dim3 cluster_id_in_grid() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t x, y, z; + asm volatile("mov.u32 %0, %%clusterid.x;\n" : "=r"(x) : ); + asm volatile("mov.u32 %0, %%clusterid.y;\n" : "=r"(y) : ); + asm volatile("mov.u32 %0, %%clusterid.z;\n" : "=r"(z) : ); + return {x, y, z}; +#elif defined(__CUDA_ARCH__) + // MSVC requires protecting use of blockIdx with __CUDA_ARCH__. + return blockIdx; +#elif defined(_MSC_VER) + CUTE_INVALID_CONTROL_PATH("cluster_id_in_grid() can only be called on device"); + return {0, 0, 0}; +#else + return {0, 0, 0}; +#endif +} + +// Returns the relative dim3 block rank local to the cluster. +CUTE_DEVICE dim3 block_id_in_cluster() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t x, y, z; + asm volatile("mov.u32 %0, %%cluster_ctaid.x;\n" : "=r"(x) : ); + asm volatile("mov.u32 %0, %%cluster_ctaid.y;\n" : "=r"(y) : ); + asm volatile("mov.u32 %0, %%cluster_ctaid.z;\n" : "=r"(z) : ); + return {x, y, z}; +#else + return {0,0,0}; +#endif +} + +// Returns the dim3 cluster shape. 
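+// Falls back to {1,1,1} when the SM90 cluster ISA is unavailable, so a caller can
+// treat a non-cluster launch as a 1x1x1 cluster, e.g. (illustrative):
+//
+//   dim3 cs = cluster_shape();
+//   int cluster_size = cs.x * cs.y * cs.z;   // 1 on non-cluster launches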
+CUTE_DEVICE dim3 cluster_shape() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t x, y, z; + asm volatile("mov.u32 %0, %%cluster_nctaid.x;\n" : "=r"(x) : ); + asm volatile("mov.u32 %0, %%cluster_nctaid.y;\n" : "=r"(y) : ); + asm volatile("mov.u32 %0, %%cluster_nctaid.z;\n" : "=r"(z) : ); + return {x, y, z}; +#else + return {1,1,1}; +#endif +} + +// Get 1D ctaid in a cluster. +CUTE_DEVICE uint32_t block_rank_in_cluster() +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t rank; + asm volatile("mov.u32 %0, %%cluster_ctarank;\n" : "=r"(rank) :); + return rank; +#else + return 0; +#endif +} + +// Set the destination block-ID in cluster for a given SMEM Address +CUTE_DEVICE uint32_t set_block_rank(uint32_t smemAddr, uint32_t rank) +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t result; + asm volatile("mapa.shared::cluster.u32 %0, %1, %2;\n" + : "=r"(result) + : "r"(smemAddr), "r"(rank)); + return result; +#else + return smemAddr; +#endif +} + +// Elect one thread in the warp. The elected thread gets its predicate set to true, all others obtain false. +CUTE_HOST_DEVICE uint32_t elect_one_sync() +{ +#if defined(CUTE_ARCH_ELECT_ONE_SM90_ENABLED) + uint32_t pred = 0; + uint32_t laneid = 0; + asm volatile( + "{\n" + ".reg .b32 %%rx;\n" + ".reg .pred %%px;\n" + " elect.sync %%rx|%%px, %2;\n" + "@%%px mov.s32 %1, 1;\n" + " mov.s32 %0, %%rx;\n" + "}\n" + : "+r"(laneid), "+r"(pred) + : "r"(0xFFFFFFFF)); + return pred; +#elif defined(__CUDA_ARCH__) + return (threadIdx.x % 32) == 0; +#else + return true; +#endif +} + +struct ElectOneLaneIdReturnType { + uint32_t is_leader; + uint32_t leader_lane_id; +}; + +CUTE_HOST_DEVICE +ElectOneLaneIdReturnType +elect_one_leader_sync() +{ +#if defined(CUTE_ARCH_ELECT_ONE_SM90_ENABLED) + uint32_t pred = 0; + uint32_t laneid = 0; + asm volatile( + "{\n" + ".reg .b32 %%rx;\n" + ".reg .pred %%px;\n" + " elect.sync %%rx|%%px, %2;\n" + "@%%px mov.s32 %1, 1;\n" + " mov.s32 %0, %%rx;\n" + "}\n" + : "+r"(laneid), "+r"(pred) + : "r"(0xFFFFFFFF)); + return {pred, laneid}; +#elif defined(__CUDA_ARCH__) + return {(threadIdx.x % 32) == 0, 0}; +#else + return {true, 0}; +#endif +} + +// Store value to remote shared memory in the cluster +CUTE_DEVICE +void +store_shared_remote(uint32_t value, uint32_t smem_addr, uint32_t mbarrier_addr, uint32_t dst_cta_rank) +{ +#if defined(CUTE_ARCH_CLUSTER_SM90_ENABLED) + uint32_t dsmem_addr = set_block_rank(smem_addr, dst_cta_rank); + uint32_t remote_barrier_addr = set_block_rank(mbarrier_addr, dst_cta_rank); + asm volatile("st.async.shared::cluster.mbarrier::complete_tx::bytes.u32 [%0], %1, [%2];" + : : "r"(dsmem_addr), "r"(value), "r"(remote_barrier_addr)); +#endif +} + +} // end namespace cute diff --git a/include/cute/arch/config.hpp b/include/cute/arch/config.hpp new file mode 100644 index 0000000000..84d7779a34 --- /dev/null +++ b/include/cute/arch/config.hpp @@ -0,0 +1,50 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTLASS_ARCH_MMA_SMxx_ENABLED + +// TMA instructions +#if defined(CUTLASS_ARCH_MMA_SM90_ENABLED) +# define CUTE_ARCH_TMA_SM90_ENABLED +#endif + +#if defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_ENABLED) +# define CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED +#endif + +// STSM +#if defined(CUTLASS_ARCH_MMA_SM90_ENABLED) +# define CUTE_ARCH_STSM_SM90_ENABLED +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cute/arch/copy.hpp b/include/cute/arch/copy.hpp new file mode 100644 index 0000000000..47dbef2f55 --- /dev/null +++ b/include/cute/arch/copy.hpp @@ -0,0 +1,107 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +// +// Direct Copy for any specific types +// + +template +struct UniversalCopy +{ + using SRegisters = S[1]; + using DRegisters = D[1]; + + // Sanity + static_assert(sizeof_bits_v >= 8); + static_assert(sizeof_bits_v >= 8); + + CUTE_HOST_DEVICE static constexpr void + copy(S const& src, + D & dst) + { + dst = src; + } +}; + +// +// Placeholder for the copy algorithm's stronger auto-vectorizing behavior +// that assumes alignment of pointers and dynamic layouts up to MaxVecBits +// + +template +struct AutoVectorizingCopyWithAssumedAlignment + : UniversalCopy> +{ + static_assert(MaxVecBits == 8 || MaxVecBits == 16 || MaxVecBits == 32 || MaxVecBits == 64 || MaxVecBits == 128, + "Expected MaxVecBits to be 8 or 16 or 32 or 64 or 128 for alignment and performance."); +}; + +// +// AutoVectorizingCopy alias assumes maximal alignment of pointers and dynamic strides. +// If this is not the case then AutoVectorizingCopyWithAssumedAlignment should be used instead +// + +using AutoVectorizingCopy = AutoVectorizingCopyWithAssumedAlignment<128>; + +// +// DefaultCopy alias does not assume alignment of pointers or dynamic strides. +// + +using DefaultCopy = AutoVectorizingCopyWithAssumedAlignment<8>; + +// +// Copy policy automatically selecting between +// UniversalCopy and cp.async , based on type and memory space. +// +struct AutoCopyAsync {}; + +// +// Global memory prefetch into L2 +// + +CUTE_HOST_DEVICE static void +prefetch(void const* gmem_ptr) +{ +#if defined(__CUDA_ARCH__) + asm volatile("prefetch.global.L2 [%0];\n" : : "l"(gmem_ptr) : "memory"); +#endif +} + +} // end namespace cute diff --git a/include/cute/arch/copy_sm50.hpp b/include/cute/arch/copy_sm50.hpp new file mode 100644 index 0000000000..925d9ebe37 --- /dev/null +++ b/include/cute/arch/copy_sm50.hpp @@ -0,0 +1,98 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
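// ------------------------------------------------------------------------------------------------
// [Editorial sketch, not part of the diff] Minimal device-side illustration of the copy policies
// defined in copy.hpp above. Only UniversalCopy, DefaultCopy, AutoVectorizingCopy and prefetch
// come from that file; the function and parameter names here are hypothetical.
__device__ void copy_policy_example(float const& src, float& dst, void const* gmem_ptr)
{
  // UniversalCopy performs a plain assignment from source to destination.
  cute::UniversalCopy<float>::copy(src, dst);

  // DefaultCopy (an alias for AutoVectorizingCopyWithAssumedAlignment<8>) assumes only 8-bit
  // alignment, while AutoVectorizingCopy assumes 128-bit alignment and therefore lets the
  // copy algorithm vectorize more aggressively when pointers and dynamic strides permit.

  // Prefetch a global-memory line into L2 ahead of its first use.
  cute::prefetch(gmem_ptr);
}
// ------------------------------------------------------------------------------------------------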
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 500 + #define CUTE_ARCH_WARP_SHUFFLE_ENABLED 1 +#endif + +namespace cute +{ +// Shuffle data between thread pair (0, 1), (2, 3), etc. +struct SM50_Shuffle_U32_2x2Trans_XOR1 +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_WARP_SHUFFLE_ENABLED) + uint32_t x0 = src0; + uint32_t y0 = __shfl_xor_sync(0xffffffff, x0, 1); + + uint32_t x1 = src1; + uint32_t y1 = __shfl_xor_sync(0xffffffff, x1, 1); + + if (threadIdx.x % 2 == 0) { + dst1 = y0; + } + else { + dst0 = y1; + } +#else + CUTE_INVALID_CONTROL_PATH("Trying to use __shfl_xor_sync without CUTE_ARCH_WARP_SHUFFLE_ENABLED."); +#endif + } +}; + +// Shuffle data between thread pair (0, 4), (1, 5), etc. +struct SM50_Shuffle_U32_2x2Trans_XOR4 +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_WARP_SHUFFLE_ENABLED) + uint32_t x0 = threadIdx.x & 4 ? src0 : src1; + uint32_t y0 = __shfl_xor_sync(0xffffffff, x0, 4); + + // Replace detination register with shuffle result. + if (threadIdx.x & 0x4) { + dst0 = y0; + } + else { + dst1 = y0; + } +#else + CUTE_INVALID_CONTROL_PATH("Trying to use __shfl_xor_sync without CUTE_ARCH_WARP_SHUFFLE_ENABLED."); +#endif + } +}; + + +} // end namespace cute diff --git a/include/cute/arch/copy_sm75.hpp b/include/cute/arch/copy_sm75.hpp new file mode 100644 index 0000000000..3d3d37acb0 --- /dev/null +++ b/include/cute/arch/copy_sm75.hpp @@ -0,0 +1,236 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +// Config +#if defined(__clang__) && defined(__CUDA__) + // ldmatrix PTX instructions added in Clang 14: https://reviews.llvm.org/D107046 + // ... but will not work until Clang 15: + // * https://reviews.llvm.org/D121666 + // * https://reviews.llvm.org/D126846 + #define CUTE_ARCH_CLANG_SUPPORTS_LDSM_SM75 (__clang_major__ >= 15) +#endif + +#if defined(__NVCC__) || defined(__CUDACC_RTC__) + // ldmatrix PTX instruction added in CUDA 10.2+ + #define CUTE_ARCH_NVCC_SUPPORTS_LDSM_SM75 ((__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) || __CUDACC_VER_MAJOR__ >= 11) +#endif + +#if ! defined(CUTE_ARCH_LDSM_SM75_SUPPORTED) + #define CUTE_ARCH_LDSM_SM75_SUPPORTED (CUTE_ARCH_NVCC_SUPPORTS_LDSM_SM75 || CUTE_ARCH_CLANG_SUPPORTS_LDSM_SM75) +#endif + +#if ! 
defined(CUTE_ARCH_LDSM_SM75_ENABLED) + #define CUTE_ARCH_LDSM_SM75_ENABLED (CUTE_ARCH_LDSM_SM75_SUPPORTED) +#endif + +#if (CUTE_ARCH_LDSM_SM75_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 750 + #define CUTE_ARCH_LDSM_SM75_ACTIVATED 1 +#endif + +namespace cute +{ + +struct SM75_U32x1_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst) + { +#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.x1.m8n8.shared.b16 {%0}, [%1];\n" + : "=r"(dst) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED."); +#endif + } +}; + +struct SM75_U32x2_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.x2.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED."); +#endif + } +}; + +struct SM75_U32x4_LDSM_N +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.x4.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED."); +#endif + } +}; + +struct SM75_U16x2_LDSM_T +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst) + { +#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.x1.trans.m8n8.shared.b16 {%0}, [%1];\n" + : "=r"(dst) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED."); +#endif + } +}; + +struct SM75_U16x4_LDSM_T +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1) + { +#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.x2.trans.m8n8.shared.b16 {%0, %1}, [%2];\n" + : "=r"(dst0), "=r"(dst1) + : "r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED."); +#endif + } +}; + +struct SM75_U16x8_LDSM_T +{ + using SRegisters = uint128_t[1]; + using DRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + copy(uint128_t const& smem_src, + uint32_t& dst0, uint32_t& dst1, uint32_t& dst2, uint32_t& dst3) + { +#if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_src); + asm volatile ("ldmatrix.sync.aligned.x4.trans.m8n8.shared.b16 {%0, %1, %2, %3}, [%4];\n" + : "=r"(dst0), "=r"(dst1), "=r"(dst2), "=r"(dst3) + : 
"r"(smem_int_ptr)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use ldmatrix without CUTE_ARCH_LDSM_SM75_ACTIVATED."); +#endif + } +}; + +// +// Legacy LDSM interfaces that aren't very useful +// + +template +CUTE_HOST_DEVICE +void +copy_ldsm(uint128_t const* const smem_ptr, + T* rmem_ptr) +{ + uint32_t* reg_ptr = reinterpret_cast(rmem_ptr); + + // if constexpr + if (sizeof(T) == 4) { + SM75_U32x1_LDSM_N::copy(smem_ptr[0], reg_ptr[0]); + } + else if (sizeof(T) == 8) { + SM75_U32x2_LDSM_N::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1]); + } + else if (sizeof(T) == 16) { + SM75_U32x4_LDSM_N::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3]); + } + else { + static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported"); + } +} + +template +CUTE_HOST_DEVICE +void +copy_ldsm_trans(uint128_t const* const smem_ptr, + T* rmem_ptr) +{ + uint32_t* reg_ptr = reinterpret_cast(rmem_ptr); + + // if constexpr + if (sizeof(T) == 4) { + SM75_U16x2_LDSM_T::copy(smem_ptr[0], reg_ptr[0]); + } + else if (sizeof(T) == 8) { + SM75_U16x4_LDSM_T::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1]); + } + else if (sizeof(T) == 16) { + SM75_U16x8_LDSM_T::copy(smem_ptr[0], reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3]); + } + else { + static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported"); + } +} + +} // end namespace cute diff --git a/include/cute/arch/copy_sm80.hpp b/include/cute/arch/copy_sm80.hpp new file mode 100644 index 0000000000..e04181bfe9 --- /dev/null +++ b/include/cute/arch/copy_sm80.hpp @@ -0,0 +1,198 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +#include + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) +# define CUTE_ARCH_CP_ASYNC_SM80_ENABLED +#endif + +namespace cute +{ + +/// Copy via cp.async with caching at all levels +template +struct SM80_CP_ASYNC_CACHEALWAYS +{ + using SRegisters = TS[1]; + using DRegisters = TD[1]; + + static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)"); + static_assert(sizeof(TS) == 4 || sizeof(TS) == 8 || sizeof(TS) == 16, "cp.async sizeof(TS) is not supported"); + + CUTE_HOST_DEVICE static void + copy(TS const& gmem_src, + TD & smem_dst) + { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + TS const* gmem_ptr = &gmem_src; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2;\n" + :: "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(sizeof(TS))); +#else + CUTE_INVALID_CONTROL_PATH("Support for cp.async instructions has not been enabled"); +#endif + } +}; + +/// Copy via cp.async with caching at global level +template +struct SM80_CP_ASYNC_CACHEGLOBAL +{ + using SRegisters = TS[1]; + using DRegisters = TD[1]; + + static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)"); + static_assert(sizeof(TS) == 16, "cp.async sizeof(TS) is not supported"); + + CUTE_HOST_DEVICE static void + copy(TS const& gmem_src, + TD & smem_dst) + { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + TS const* gmem_ptr = &gmem_src; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2;\n" + :: "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(sizeof(TS))); +#else + CUTE_INVALID_CONTROL_PATH("Support for cp.async instructions has not been enabled"); +#endif + } +}; + +/// Copy via cp.async with caching at all levels +template +struct SM80_CP_ASYNC_CACHEALWAYS_ZFILL +{ + using SRegisters = TS[1]; + using DRegisters = TD[1]; + + static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)"); + static_assert(sizeof(TS) == 4 || sizeof(TS) == 8 || sizeof(TS) == 16, "cp.async sizeof(TS) is not supported"); + + CUTE_HOST_DEVICE static void + copy(TS const& gmem_src, + TD & smem_dst, + bool pred) + { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + TS const* gmem_ptr = &gmem_src; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + int src_size = pred ? sizeof(TS) : 0; + asm volatile("cp.async.ca.shared.global.L2::128B [%0], [%1], %2, %3;\n" + :: "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(sizeof(TS)), + "r"(src_size)); +#else + CUTE_INVALID_CONTROL_PATH("Support for cp.async instructions has not been enabled"); +#endif + } +}; + +/// Copy via cp.async with caching at global level +template +struct SM80_CP_ASYNC_CACHEGLOBAL_ZFILL +{ + using SRegisters = TS[1]; + using DRegisters = TD[1]; + + static_assert(sizeof(TS) == sizeof(TD), "cp.async requires sizeof(src_value_type) == sizeof(dst_value_type)"); + static_assert(sizeof(TS) == 16, "cp.async sizeof(TS) is not supported"); + + CUTE_HOST_DEVICE static void + copy(TS const& gmem_src, + TD & smem_dst, + bool pred) + { +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + TS const* gmem_ptr = &gmem_src; + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + int src_size = pred ? 
sizeof(TS) : 0; + asm volatile("cp.async.cg.shared.global.L2::128B [%0], [%1], %2, %3;\n" + :: "r"(smem_int_ptr), + "l"(gmem_ptr), + "n"(sizeof(TS)), + "r"(src_size)); +#else + CUTE_INVALID_CONTROL_PATH("Support for cp.async instructions has not been enabled"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block. +CUTE_HOST_DEVICE +void +cp_async_fence() +{ +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + asm volatile("cp.async.commit_group;\n" ::); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +/// Blocks until all but N previous cp.async.commit_group operations have committed. +template +CUTE_HOST_DEVICE +void +cp_async_wait() +{ +#if defined(CUTE_ARCH_CP_ASYNC_SM80_ENABLED) + if constexpr (N == 0) { + asm volatile("cp.async.wait_all;\n" ::); + } else { + asm volatile("cp.async.wait_group %0;\n" :: "n"(N)); + } +#endif +} + +template +CUTE_HOST_DEVICE +void +cp_async_wait(Int) +{ + return cp_async_wait(); +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/include/cute/arch/copy_sm90.hpp b/include/cute/arch/copy_sm90.hpp new file mode 100644 index 0000000000..bcb3b7d19c --- /dev/null +++ b/include/cute/arch/copy_sm90.hpp @@ -0,0 +1,219 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
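// ------------------------------------------------------------------------------------------------
// [Editorial sketch, not part of the diff] Typical staging pattern built from the cp.async
// primitives above: issue asynchronous 16-byte global->shared copies, close the commit group,
// then wait for it before the CTA reads the data. The buffer sizes and indexing are illustrative.
__device__ void cp_async_example(cute::uint128_t const* gmem, cute::uint128_t* smem, int n)
{
  int tid = threadIdx.x;                      // `smem` must point into __shared__ storage
  if (tid < n) {
    // Asynchronous 16-byte copy cached at the global (L2) level only.
    cute::SM80_CP_ASYNC_CACHEGLOBAL<cute::uint128_t>::copy(gmem[tid], smem[tid]);
  }
  // Close the current commit group ...
  cute::cp_async_fence();
  // ... block until zero groups remain outstanding, then make the writes visible to the CTA.
  cute::cp_async_wait<0>();
  __syncthreads();
}
// ------------------------------------------------------------------------------------------------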
+ * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // CUTE_ARCH_TMA_SMxx_ENABLED +#include + +namespace cute +{ + +struct SM90_U32x1_STSM_N +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src, + uint128_t & smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.x1.m8n8.shared.b16 [%0], {%1};\n" + :: "r"(smem_int_ptr), + "r"(src)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); +#endif + } +}; + +struct SM90_U32x2_STSM_N +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.x2.m8n8.shared.b16 [%0], {%1, %2};\n" + :: "r"(smem_int_ptr), + "r"(src0), "r"(src1)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); +#endif + } +}; + +struct SM90_U32x4_STSM_N +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.x4.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" + :: "r"(smem_int_ptr), + "r"(src0), "r"(src1), "r"(src2), "r"(src3)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); +#endif + } +}; + +struct SM90_U16x2_STSM_T +{ + using SRegisters = uint32_t[1]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.x1.trans.m8n8.shared.b16 [%0], {%1};\n" + :: "r"(smem_int_ptr), + "r"(src)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); +#endif + } +}; + +struct SM90_U16x4_STSM_T +{ + using SRegisters = uint32_t[2]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.x2.trans.m8n8.shared.b16 [%0], {%1, %2};\n" + :: "r"(smem_int_ptr), + "r"(src0), "r"(src1)); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); +#endif + } +}; + +struct SM90_U16x8_STSM_T +{ + using SRegisters = uint32_t[4]; + using DRegisters = uint128_t[1]; + + CUTE_HOST_DEVICE static void + copy(uint32_t const& src0, uint32_t const& src1, uint32_t const& src2, uint32_t const& src3, + uint128_t& smem_dst) + { +#if defined(CUTE_ARCH_STSM_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_dst); + asm volatile ("stmatrix.sync.aligned.x4.trans.m8n8.shared.b16 [%0], {%1, %2, %3, %4};\n" + :: "r"(smem_int_ptr), + "r"(src0), "r"(src1), "r"(src2), "r"(src3)); +#else + 
CUTE_INVALID_CONTROL_PATH("Trying to use stmatrix without CUTE_ARCH_STSM_SM90_ENABLED."); +#endif + } +}; + +// +// Legacy STSM interfaces that aren't very useful +// + +template +CUTE_HOST_DEVICE +void +copy_stsm(T const* const rmem_ptr, + uint128_t* const smem_ptr) +{ + uint32_t const* reg_ptr = reinterpret_cast(rmem_ptr); + + // if constexpr + if (sizeof(T) == 4) { + SM90_U32x1_STSM_N::copy(reg_ptr[0], smem_ptr[0]); + } + else if (sizeof(T) == 8) { + SM90_U32x2_STSM_N::copy(reg_ptr[0], reg_ptr[1], smem_ptr[0]); + } + else if (sizeof(T) == 16) { + SM90_U32x4_STSM_N::copy(reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3], smem_ptr[0]); + } + else { + static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported"); + } +} + +template +CUTE_HOST_DEVICE +void +copy_stsm_trans(T const* const rmem_ptr, + uint128_t* const smem_ptr) +{ + uint32_t const* reg_ptr = reinterpret_cast(rmem_ptr); + + // if constexpr + if (sizeof(T) == 4) { + SM90_U16x2_STSM_T::copy(reg_ptr[0], smem_ptr[0]); + } + else if (sizeof(T) == 8) { + SM90_U16x4_STSM_T::copy(reg_ptr[0], reg_ptr[1], smem_ptr[0]); + } + else if (sizeof(T) == 16) { + SM90_U16x8_STSM_T::copy(reg_ptr[0], reg_ptr[1], reg_ptr[2], reg_ptr[3], smem_ptr[0]); + } + else { + static_assert(sizeof(T) == 4 || sizeof(T) == 8 || sizeof(T) == 16, "sizeof(T) is not supported"); + } +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/arch/copy_sm90_desc.hpp b/include/cute/arch/copy_sm90_desc.hpp new file mode 100644 index 0000000000..cc0bf4a392 --- /dev/null +++ b/include/cute/arch/copy_sm90_desc.hpp @@ -0,0 +1,440 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+#pragma once
+
+#include "cutlass/numeric_types.h"
+
+#if !defined(__CUDACC_RTC__)
+#include
+#include
+#endif
+
+#include
+
+#include // cute::cast_smem_ptr_to_uint
+#include // CUTE_ARCH_TMA_SMxx_ENABLED
+#include
+#include
+
+#include
+#include
+#include
+#include
+
+namespace cute
+{
+
+//////////////////////////////////////////////////////////////////////////////////////////////////////
+/// Barriers are 64 bits of user-managed information used in broadly two types of synchronization patterns
+/// 1) arrive/wait on threads (usage: cp.async and warp-specialized kernels)
+/// 2) transaction-based (usage: TMA transaction where a CTA issues one transaction)
+//////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Initialize barrier present in shared memory
+CUTE_HOST_DEVICE
+void
+initialize_barrier(uint64_t& smem_barrier,                 // 64 bits user-managed barrier in smem
+                   int thread_count = 1)                   // Thread count expected to arrive/wait on this barrier
+{
+#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
+  uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier);
+  asm volatile ("mbarrier.init.shared::cta.b64 [%0], %1;\n"
+    :: "r"(smem_int_ptr),
+       "r"(thread_count));
+#endif
+}
+
+// Set the number of bytes transferred per transaction and perform an arrive operation as well
+CUTE_HOST_DEVICE
+void
+set_barrier_transaction_bytes(uint64_t& smem_barrier,      // 64 bits user-managed barrier in smem
+                              uint32_t bytes)              // Number of bytes transferred per TMA transaction
+{
+#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
+  uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier);
+  asm volatile ("mbarrier.arrive.expect_tx.shared::cta.b64 _, [%0], %1;\n"
+    :: "r"(smem_int_ptr),
+       "r"(bytes));
+#endif
+}
+
+// Barrier wait
+CUTE_HOST_DEVICE
+void
+wait_barrier(uint64_t& smem_barrier,                       // 64 bits user-managed barrier in smem
+             int phase_bit)                                // Current phase bit the barrier is waiting to flip
+{
+#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
+  uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier);
+  asm volatile(
+    "{\n"
+    ".reg .pred P1;\n"
+    "LAB_WAIT:\n"
+    "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1;\n"
+    "@P1 bra DONE;\n"
+    "bra LAB_WAIT;\n"
+    "DONE:\n"
+    "}\n"
+    :: "r"(smem_int_ptr),
+       "r"(phase_bit));
+
+#endif
+}
+
+// Barrier arrive
+CUTE_HOST_DEVICE
+void
+arrive_barrier(uint64_t& smem_barrier)                     // 64 bits user-managed barrier in smem
+{
+#if defined(CUTE_ARCH_TMA_SM90_ENABLED)
+  uint32_t smem_int_ptr = cast_smem_ptr_to_uint(&smem_barrier);
+  asm volatile(
+    "{\n"
+    ".reg .b64 state; \n"
+    "mbarrier.arrive.shared::cta.b64 state, [%0];\n"
+    "}\n"
+    :: "r"(smem_int_ptr));
+#endif
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// TMA Descriptor and utilities
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+namespace
TMA { + +enum class SmemSwizzleBits : uint8_t { + DISABLE = 0, + B32 = 1, + B64 = 2, + B128 = 3, +}; + +enum class SmemSwizzleBase : uint8_t { + SWIZZLE_BASE_16B = 0, +}; + +enum class OOBFill : uint8_t { + ZERO = 0, + CONSTANT = 1, +}; + +CUTE_HOST_DEVICE char const* to_string(OOBFill const& t) { + switch (t) { + case OOBFill::ZERO: return "ZERO"; + case OOBFill::CONSTANT: return "CONSTANT"; + } + return nullptr; +} + +enum class L2Promotion : uint8_t { + DISABLE = 0, + B64 = 1, + B128 = 2, + B256 = 3, +}; + +CUTE_HOST_DEVICE char const* to_string(L2Promotion const& t) { + switch (t) { + case L2Promotion::DISABLE: return "DISABLE"; + case L2Promotion::B64: return "B64"; + case L2Promotion::B128: return "B128"; + case L2Promotion::B256: return "B256"; + } + return nullptr; +} + +// Aux parameters which are independent with the problem size +struct DescriptorAuxParams { + OOBFill oobfill_ = OOBFill::ZERO; + L2Promotion l2promo_ = L2Promotion::DISABLE; +}; + +enum class CacheHintSm90 : uint64_t { + EVICT_NORMAL = 0x1000000000000000, + EVICT_FIRST = 0x12F0000000000000, + EVICT_LAST = 0x14F0000000000000, +}; + +#if (__CUDACC_VER_MAJOR__ >= 12) + +#if !defined(__CUDACC_RTC__) +/// @return The TMA descriptor datatype enum corresponding to T. +template +inline CUtensorMapDataType +to_CUtensorMapDataType() { + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT8; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT16; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT32; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_UINT64; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_INT32; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_INT64; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT16; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT32; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_FLOAT64; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_BFLOAT16; } else + if constexpr (is_same_v) { return CU_TENSOR_MAP_DATA_TYPE_TFLOAT32; } else + { static_assert(sizeof(T) < 0, "Unknown TMA Format!"); } +} + +inline CUtensorMapSwizzle +to_CUtensorMapSwizzle(SmemSwizzleBits const& t, SmemSwizzleBase const& b) { + switch (t) { + default: assert(false && "Unsupported pair of SmemSwizzleBits and SmemSwizzleBase!"); + case SmemSwizzleBits::DISABLE: + assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 0B swizzle bits."); + return CU_TENSOR_MAP_SWIZZLE_NONE; + case SmemSwizzleBits::B32: + assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 32B swizzle bits."); + return CU_TENSOR_MAP_SWIZZLE_32B; + case SmemSwizzleBits::B64: + assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 64B swizzle bits."); + return CU_TENSOR_MAP_SWIZZLE_64B; + case SmemSwizzleBits::B128: + assert((b == SmemSwizzleBase::SWIZZLE_BASE_16B) && "Expected 16B swizzle base for 128B swizzle bits."); + return CU_TENSOR_MAP_SWIZZLE_128B; + } +} + +inline CUtensorMapFloatOOBfill +to_CUtensorMapFloatOOBfill(OOBFill const& t) { + switch(t) { + default: assert(false && "Unknown OOBFill!"); + case OOBFill::ZERO: return CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + case 
OOBFill::CONSTANT: return CU_TENSOR_MAP_FLOAT_OOB_FILL_NAN_REQUEST_ZERO_FMA; + } +} + +inline CUtensorMapL2promotion +to_CUtensorMapL2promotion(L2Promotion const& t) { + switch(t) { + default: assert(false && "Unknown L2Promotion!"); + case L2Promotion::DISABLE: return CU_TENSOR_MAP_L2_PROMOTION_NONE; + case L2Promotion::B64: return CU_TENSOR_MAP_L2_PROMOTION_L2_64B; + case L2Promotion::B128: return CU_TENSOR_MAP_L2_PROMOTION_L2_128B; + case L2Promotion::B256: return CU_TENSOR_MAP_L2_PROMOTION_L2_256B; + } +} + +#endif // !defined(__CUDACC_RTC__) + +#endif // (__CUDACC_VER_MAJOR__ >= 12) + +} // end namespace TMA + +#if (__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__) + using TmaDescriptor = CUtensorMap; + using Im2ColTmaDescriptor = CUtensorMap; +#else + using TmaDescriptor = struct alignas(64) { char bytes[128]; }; + using Im2ColTmaDescriptor = struct alignas(64) { char bytes[128]; }; +#endif +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Initiates a TensorMap Prefetch +//////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTE_HOST_DEVICE +void +prefetch_tma_descriptor(TmaDescriptor const* desc_ptr) +{ +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + // Prefetch TMA Descriptor using generic addressing (i.e. no specific state space: const or param) + asm volatile ( + "prefetch.tensormap [%0];" + : + : "l"(gmem_int_desc) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use TMA Descriptor Prefetch without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Perform a TensorMap modification (by each field) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Replace tensor pointer directly in GMEM +CUTE_HOST_DEVICE +void +tma_descriptor_replace_addr_in_global_mem(TmaDescriptor const* desc_ptr, + void const* const new_tensor_ptr) +{ +#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint64_t const new_desc_addr = reinterpret_cast(new_tensor_ptr); + asm volatile ( + "tensormap.replace.tile.global_address.global.b1024.b64 [%0], %1;" + :: "l"(gmem_int_desc), "l"(new_desc_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3"); +#endif +} + +// Replace tensor pointer by bringing the tensormap from GMEM into the shared memory +CUTE_HOST_DEVICE +void +tma_descriptor_replace_addr_in_shared_mem(TmaDescriptor& smem_desc, + void const* const new_tensor_ptr) +{ +#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) + uint32_t smem_int_desc = cast_smem_ptr_to_uint(&smem_desc); + uint64_t const new_desc_addr = reinterpret_cast(new_tensor_ptr); + asm volatile ( + "tensormap.replace.tile.global_address.shared::cta.b1024.b64 [%0], %1;" + :: "r"(smem_int_desc), "l"(new_desc_addr)); +#else + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3"); +#endif +} + +// Replace tensor dims and strides for GEMMs by bringing the tensormap from GMEM into the shared memory +CUTE_HOST_DEVICE +void +tma_descriptor_replace_dims_strides_in_shared_mem(TmaDescriptor & smem_desc, + cute::array const& prob_shape, + cute::array const& prob_stride) +{ +#if 
defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) + uint32_t smem_int_desc = cast_smem_ptr_to_uint(&smem_desc); + uint64_t const smem_int64_desc = 0; + asm volatile ( + "cvt.u64.u32 %0, %1;" + :: "l"(smem_int64_desc), "r"(smem_int_desc)); + asm volatile ( + "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 0, %1;" + :: "l"(smem_int64_desc), "r"(prob_shape[0])); + asm volatile ( + "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 1, %1;" + :: "l"(smem_int64_desc), "r"(prob_shape[1])); + asm volatile ( + "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 2, %1;" + :: "l"(smem_int64_desc), "r"(prob_shape[2])); + asm volatile ( + "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 3, %1;" + :: "l"(smem_int64_desc), "r"(prob_shape[3])); + asm volatile ( + "tensormap.replace.tile.global_dim.shared::cta.b1024.b32 [%0], 4, %1;" + :: "l"(smem_int64_desc), "r"(prob_shape[4])); + // Strides must be a multiple of 16. Also, stride for the intermost dimension is implicitly 1 + #if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 5))) + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[1])); + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 1, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[2])); + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 2, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[3])); + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 3, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[4])); + #else + // 4 LSBs are not included + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 0, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[1] >> 4)); + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 1, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[2] >> 4)); + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 2, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[3] >> 4)); + asm volatile ( + "tensormap.replace.tile.global_stride.shared::cta.b1024.b64 [%0], 3, %1;" + :: "l"(smem_int64_desc), "l"(prob_stride[4] >> 4)); + #endif +#else + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3"); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Perform a fused copy and fence operation (needed when modifying tensormap in shared memory) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTE_HOST_DEVICE +void +tma_descriptor_cp_fence_release(TmaDescriptor const* gmem_desc_ptr, TmaDescriptor& smem_desc) +{ +#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(gmem_desc_ptr); + uint32_t smem_int_desc = cast_smem_ptr_to_uint(&smem_desc); + asm volatile ( + "tensormap.cp_fenceproxy.global.shared::cta.tensormap::generic.release.gpu.sync.aligned [%0], [%1], 128;" + :: "l"(gmem_int_desc), "r"(smem_int_desc)); +#else + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3"); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Perform a release fence 
operation (needed when modifying tensormap directly in GMEM) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTE_HOST_DEVICE +void +tma_descriptor_fence_release() +{ +#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) + asm volatile ("fence.proxy.tensormap::generic.release.gpu;"); +#else + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3"); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Perform a acquire fence operation +//////////////////////////////////////////////////////////////////////////////////////////////////// + +CUTE_HOST_DEVICE +void +tma_descriptor_fence_acquire(TmaDescriptor const* desc_ptr) +{ +#if defined(CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "fence.proxy.tensormap::generic.acquire.gpu [%0], 128;" + : + : "l"(gmem_int_desc) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Using TMA Descriptor modification without CUTE_ARCH_DEVICE_MODIFIABLE_TMA_SM90_ENABLED and CUDA 12.3"); +#endif +} + +/////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/include/cute/arch/copy_sm90_tma.hpp b/include/cute/arch/copy_sm90_tma.hpp new file mode 100644 index 0000000000..fb33d63cad --- /dev/null +++ b/include/cute/arch/copy_sm90_tma.hpp @@ -0,0 +1,1395 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
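// ------------------------------------------------------------------------------------------------
// [Editorial sketch, not part of the diff] Consumer-side use of the mbarrier helpers defined in
// copy_sm90_desc.hpp above, as they are typically paired with a TMA load. The single-stage setup
// and the externally supplied `phase` are illustrative; real kernels also fence the barrier
// initialization and track one phase bit per pipeline stage. `smem_bar` must live in __shared__.
__device__ void mbarrier_example(uint64_t& smem_bar, uint32_t bytes_expected, int phase)
{
  if (threadIdx.x == 0) {
    // One thread owns the barrier: initialize it for a single arrival, then post how many
    // bytes the upcoming TMA transaction will deposit.
    cute::initialize_barrier(smem_bar, /*thread_count=*/1);
    cute::set_barrier_transaction_bytes(smem_bar, bytes_expected);
    // ... issue the TMA load here, passing &smem_bar as its mbarrier ...
  }
  __syncthreads();
  // All consumers spin until the barrier's phase bit flips, i.e. the transaction has completed.
  cute::wait_barrier(smem_bar, phase);
}
// ------------------------------------------------------------------------------------------------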
+ * + **************************************************************************************************/ +#pragma once + +#include + +#include // CUTE_ARCH_TMA_SMxx_ENABLED +#include +#include +#include "cutlass/arch/synclog.hpp" + +namespace cute +{ + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_LOAD : Initiates a TMA copy from global memory to shared memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_1D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3}], [%2], %4;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.1d.L2.global" + " [%0, {%1}];" + : + : "l"(gmem_int_desc), + "r"(crd0) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD_2D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4}], [%2], %5;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.2d.L2.global" + " [%0, {%1, %2}];" + : + : "l"(gmem_int_desc), + "r"(crd0), "r"(crd1) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t 
smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5}], [%2], %6;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.3d.L2.global" + " [%0, {%1, %2, %3}];" + : + : "l"(gmem_int_desc), + "r"(crd0), "r"(crd1), "r"(crd2) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6}], [%2], %7;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.4d.L2.global" + " [%0, {%1, %2, %3, %4}];" + : + : "l"(gmem_int_desc), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.L2::cache_hint" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], %8;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint) + : "memory"); +#else + 
CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.5d.L2.global" + " [%0, {%1, %2, %3, %4, %5}];" + : + : "l"(gmem_int_desc), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0) + { + return SM90_TMA_LOAD_1D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { + return SM90_TMA_LOAD_2D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + return SM90_TMA_LOAD_3D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + return SM90_TMA_LOAD_4D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + return SM90_TMA_LOAD_5D::copy(desc_ptr, mbar_ptr, cache_hint, smem_ptr, crd0, crd1, crd2, crd3, crd4); + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0) + { + return SM90_TMA_LOAD_1D::PREFETCH::copy(desc_ptr, crd0); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1) + { + return SM90_TMA_LOAD_2D::PREFETCH::copy(desc_ptr, crd0, crd1); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + return SM90_TMA_LOAD_3D::PREFETCH::copy(desc_ptr, crd0, crd1, crd2); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + return SM90_TMA_LOAD_4D::PREFETCH::copy(desc_ptr, crd0, crd1, crd2, crd3); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + return SM90_TMA_LOAD_5D::PREFETCH::copy(desc_ptr, crd0, crd1, crd2, crd3, crd4); + } + }; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_LOAD im2col: Initiates a TMA copy, in im2col mode, from global memory to shared memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_IM2COL_3D +{ + 
CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + // Copy from global to shared::cluster. + asm volatile ( + "cp.async.bulk.tensor.3d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4, %5}], [%2], {%6};" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_n), + "h"(offset_w) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.3d.L2.global.im2col" + " [%0, {%1, %2, %3}], {%4};" + : + : "l"(gmem_int_desc), + "r"(coord_c), "r"(coord_w), "r"(coord_n), + "h"(offset_w) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD_IM2COL_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + // Copy from global to shared::cluster. 
+ asm volatile ( + "cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8};" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n), + "h"(offset_w), "h"(offset_h) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.4d.L2.global.im2col" + " [%0, {%1, %2, %3, %4}], {%5, %6};" + : + : "l"(gmem_int_desc), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n), + "h"(offset_w), "h"(offset_h) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD_IM2COL_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + // Copy from global to shared::cluster. 
+ asm volatile ( + "cp.async.bulk.tensor.5d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], {%8, %9, %10};" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_d), "r"(coord_n), + "h"(offset_w), "h"(offset_h), "h"(offset_d) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + asm volatile ( + "cp.async.bulk.prefetch.tensor.5d.L2.global.im2col" + " [%0, {%1, %2, %3, %4, %5}], {%6, %7, %8};" + : + : "l"(gmem_int_desc), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_d), "r"(coord_n), + "h"(offset_w), "h"(offset_h), "h"(offset_d) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_TMA_LOAD_IM2COL +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { + return SM90_TMA_LOAD_IM2COL_3D::copy(desc_ptr, mbar_ptr, smem_ptr, + coord_c, coord_w, coord_n, + offset_w); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h) + { + return SM90_TMA_LOAD_IM2COL_4D::copy(desc_ptr, mbar_ptr, smem_ptr, + coord_c, coord_w, coord_h, coord_n, + offset_w, offset_h); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d) + { + return SM90_TMA_LOAD_IM2COL_5D::copy(desc_ptr, mbar_ptr, smem_ptr, + coord_c, coord_w, coord_h, coord_d, coord_n, + offset_w, offset_h, offset_d); + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { + return SM90_TMA_LOAD_IM2COL_3D::PREFETCH::copy(desc_ptr, + coord_c, coord_w, coord_n, + offset_w); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h) + { + return SM90_TMA_LOAD_IM2COL_4D::PREFETCH::copy(desc_ptr, + coord_c, coord_w, coord_h, coord_n, + offset_w, offset_h); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d) + { + return SM90_TMA_LOAD_IM2COL_5D::PREFETCH::copy(desc_ptr, + coord_c, coord_w, coord_h, coord_d, coord_n, + offset_w, offset_h, offset_d); + } + }; +}; + 
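//
// Usage sketch (illustrative only; not part of the original header). It shows how the
// SM90_TMA_LOAD_IM2COL dispatcher above selects the 3-D variant from the argument count.
// `desc` is assumed to be a valid im2col TMA descriptor, `mbar` an already-initialized
// shared-memory mbarrier whose expected transaction bytes have been set, and `smem_dst`
// a suitably aligned shared-memory staging buffer; that setup is outside this sketch.
//
inline __device__ void
example_issue_im2col_load(void const* desc, uint64_t* mbar, void* smem_dst,
                          int32_t c0, int32_t w0, int32_t n0, uint16_t offset_w)
{
  // A single thread issues the bulk tensor copy; completion is signaled on `mbar`.
  if (threadIdx.x == 0) {
    SM90_TMA_LOAD_IM2COL::copy(desc, mbar, smem_dst, c0, w0, n0, offset_w);
    // The prefetch form warms L2 only (no shared-memory destination, no mbarrier);
    // it would typically target a later tile's coordinates, reused here for brevity.
    SM90_TMA_LOAD_IM2COL::PREFETCH::copy(desc, c0, w0, n0, offset_w);
  }
}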
+//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_LOAD_MULTICAST: Initiates a TMA copy from global memory to shared memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_MULTICAST_1D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.1d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4}], [%2], %3, %5;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "h"(multicast_mask), + "r"(crd0), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_MULTICAST_2D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.2d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5}], [%2], %3, %6;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "h"(multicast_mask), + "r"(crd0), "r"(crd1), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_MULTICAST_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5, %6}], [%2], %3, %7;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "h"(multicast_mask), + "r"(crd0), "r"(crd1), "r"(crd2), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_MULTICAST_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = 
reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5, %6, %7}], [%2], %3, %8;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "h"(multicast_mask), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_MULTICAST_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.shared::cluster.global.mbarrier::complete_tx::bytes.multicast::cluster.L2::cache_hint" + " [%0], [%1, {%4, %5, %6, %7, %8}], [%2], %3, %9;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "h"(multicast_mask), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4), "l"(cache_hint) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_MULTICAST +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0) + { + return SM90_TMA_LOAD_MULTICAST_1D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { + return SM90_TMA_LOAD_MULTICAST_2D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + return SM90_TMA_LOAD_MULTICAST_3D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1, crd2); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + return SM90_TMA_LOAD_MULTICAST_4D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1, crd2, crd3); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, uint64_t cache_hint, + void * smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + return SM90_TMA_LOAD_MULTICAST_5D::copy(desc_ptr, mbar_ptr, multicast_mask, cache_hint, smem_ptr, crd0, crd1, crd2, crd3, crd4); + } + + using PREFETCH = typename SM90_TMA_LOAD::PREFETCH; +}; 
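//
// Usage sketch (illustrative only; not part of the original header): a 2-D multicast
// load issued from one CTA of a cluster. `desc` is assumed to be a valid TMA
// descriptor, `mbar` an initialized shared-memory mbarrier, `cta_mask` a bitmask with
// one bit set per receiving CTA rank in the cluster, and `cache_hint` the L2
// cache-policy word forwarded to the .L2::cache_hint qualifier.
//
inline __device__ void
example_issue_multicast_load(void const* desc, uint64_t* mbar,
                             uint16_t cta_mask, uint64_t cache_hint,
                             void* smem_dst, int32_t m0, int32_t n0)
{
  // A single thread issues the copy; the same 2-D tile is delivered to every CTA
  // selected by `cta_mask`, with completion reported through the mbarrier.
  if (threadIdx.x == 0) {
    SM90_TMA_LOAD_MULTICAST::copy(desc, mbar, cta_mask, cache_hint, smem_dst, m0, n0);
  }
}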
+ +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_LOAD_MULTICAST im2col: Initiates a TMA copy, in im2col mode, from global memory to shared memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_IM2COL_MULTICAST_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + // Copy from global to shared::cluster. + asm volatile ( + "cp.async.bulk.tensor.3d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes.multicast::cluster" + " [%0], [%1, {%3, %4, %5}], [%2], {%6}, %7;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_n), + "h"(offset_w), + "h"(multicast_mask) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_IM2COL_MULTICAST_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + // Copy from global to shared::cluster. + asm volatile ( + "cp.async.bulk.tensor.4d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes.multicast::cluster" + " [%0], [%1, {%3, %4, %5, %6}], [%2], {%7, %8}, %9;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n), + "h"(offset_w), "h"(offset_h), + "h"(multicast_mask) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_IM2COL_MULTICAST_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_load(__LINE__, gmem_int_desc, smem_int_mbar, smem_int_ptr); + // Copy from global to shared::cluster. 
+ asm volatile ( + "cp.async.bulk.tensor.5d.shared::cluster.global.im2col.mbarrier::complete_tx::bytes.multicast::cluster" + " [%0], [%1, {%3, %4, %5, %6, %7}], [%2], {%8, %9, %10}, %11;" + : + : "r"(smem_int_ptr), "l"(gmem_int_desc), "r"(smem_int_mbar), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_d), "r"(coord_n), + "h"(offset_w), "h"(offset_h), "h"(offset_d), + "h"(multicast_mask) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_LOAD_IM2COL_MULTICAST +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n, + uint16_t const& offset_w) + { + return SM90_TMA_LOAD_IM2COL_MULTICAST_3D::copy(desc_ptr, mbar_ptr, multicast_mask, + smem_ptr, + coord_c, coord_w, coord_n, + offset_w); + } + + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h) + { + return SM90_TMA_LOAD_IM2COL_MULTICAST_4D::copy(desc_ptr, mbar_ptr, multicast_mask, + smem_ptr, + coord_c, coord_w, coord_h, coord_n, + offset_w, offset_h); + } + + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, uint64_t* mbar_ptr, uint16_t multicast_mask, + void * smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n, + uint16_t const& offset_w, uint16_t const& offset_h, uint16_t const& offset_d) + { + return SM90_TMA_LOAD_IM2COL_MULTICAST_5D::copy(desc_ptr, mbar_ptr, multicast_mask, + smem_ptr, + coord_c, coord_w, coord_h, coord_d, coord_n, + offset_w, offset_h, offset_d); + } + + using PREFETCH = typename SM90_TMA_LOAD_IM2COL::PREFETCH; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_STORE : Initiates a TMA copy from shared memory to global memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_STORE_1D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.1d.global.shared::cta.bulk_group [%0, {%2}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_2D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.2d.global.shared::cta.bulk_group [%0, {%2, %3}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without 
CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.global.shared::cta.bulk_group [%0, {%2, %3, %4}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1), "r"(crd2) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.global.shared::cta.bulk_group [%0, {%2, %3, %4, %5}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.global.shared::cta.bulk_group [%0, {%2, %3, %4, %5, %6}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0) + { + return SM90_TMA_STORE_1D::copy(desc_ptr, smem_ptr, crd0); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { + return SM90_TMA_STORE_2D::copy(desc_ptr, smem_ptr, crd0, crd1); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + return SM90_TMA_STORE_3D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + return SM90_TMA_STORE_4D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2, crd3); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + return SM90_TMA_STORE_5D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2, crd3, crd4); + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_STORE im2col: Initiates a TMA copy, in im2col mode, from shared memory to global memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_STORE_IM2COL_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.3d.global.shared::cta.im2col_no_offs.bulk_group" + " [%0, {%2, %3, %4}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(coord_c), "r"(coord_w), "r"(coord_n) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_IM2COL_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.4d.global.shared::cta.im2col_no_offs.bulk_group" + " [%0, {%2, %3, %4, %5}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_n) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_IM2COL_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.async.bulk.tensor.5d.global.shared::cta.im2col_no_offs.bulk_group" + " [%0, {%2, %3, %4, %5, %6}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(coord_c), "r"(coord_w), "r"(coord_h), "r"(coord_d), "r"(coord_n) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_STORE_IM2COL +{ + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_n) + { + return SM90_TMA_STORE_IM2COL_3D::copy(desc_ptr, smem_ptr, coord_c, coord_w, coord_n); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_n) + { + return SM90_TMA_STORE_IM2COL_4D::copy(desc_ptr, smem_ptr, coord_c, coord_w, coord_h, coord_n); + } + CUTE_HOST_DEVICE static void + copy(void const* desc_ptr, + void const* smem_ptr, + int32_t const& coord_c, int32_t const& coord_w, int32_t const& coord_h, int32_t const& coord_d, int32_t const& coord_n) + { + return 
SM90_TMA_STORE_IM2COL_5D::copy(desc_ptr, smem_ptr, coord_c, coord_w, coord_h, coord_d, coord_n); + } +}; + +// Fence for smem stores for subsequent TMA_STORE +CUTE_HOST_DEVICE static void +tma_store_fence() { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + cutlass::arch::synclog_emit_fence_view_async_shared(__LINE__); + asm volatile ("fence.proxy.async.shared::cta;"); +#elif defined(__CUDA_ARCH__) + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif +} + +// Indicate arrival of warp issuing TMA_STORE +CUTE_HOST_DEVICE static void +tma_store_arrive() { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + cutlass::arch::synclog_emit_tma_store_arrive(__LINE__); + asm volatile("cp.async.bulk.commit_group;"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif +} + +// Wait until at most Count committed TMA_STOREs are pending and all prior commits are complete +template +CUTE_HOST_DEVICE static void +tma_store_wait() { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + asm volatile( + "cp.async.bulk.wait_group.read %0;" + : + : "n"(Count) + : "memory"); + cutlass::arch::synclog_emit_tma_store_wait(__LINE__, Count); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// TMA_REDUCE_ADD : Initiates a TMA reduce-add from shared memory to global memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_REDUCE_ADD_1D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.reduce.async.bulk.tensor.1d.global.shared::cta.add.bulk_group [%0, {%2}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_REDUCE_ADD_2D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.reduce.async.bulk.tensor.2d.global.shared::cta.add.bulk_group [%0, {%2, %3}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_REDUCE_ADD_3D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.reduce.async.bulk.tensor.3d.global.shared::cta.add.bulk_group [%0, {%2, %3, %4}], [%1];" + : + : 
"l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1), "r"(crd2) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_REDUCE_ADD_4D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.reduce.async.bulk.tensor.4d.global.shared::cta.add.bulk_group [%0, {%2, %3, %4, %5}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_REDUCE_ADD_5D +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint64_t gmem_int_desc = reinterpret_cast(desc_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_tma_store(__LINE__, gmem_int_desc, smem_int_ptr); + asm volatile ( + "cp.reduce.async.bulk.tensor.5d.global.shared::cta.add.bulk_group [%0, {%2, %3, %4, %5, %6}], [%1];" + : + : "l"(gmem_int_desc), "r"(smem_int_ptr), + "r"(crd0), "r"(crd1), "r"(crd2), "r"(crd3), "r"(crd4) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use tma without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_TMA_REDUCE_ADD +{ + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0) + { + return SM90_TMA_REDUCE_ADD_1D::copy(desc_ptr, smem_ptr, crd0); + } + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1) + { + return SM90_TMA_REDUCE_ADD_2D::copy(desc_ptr, smem_ptr, crd0, crd1); + } + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2) + { + return SM90_TMA_REDUCE_ADD_3D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2); + } + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3) + { + return SM90_TMA_REDUCE_ADD_4D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2, crd3); + } + CUTE_HOST_DEVICE static void + copy(void const* const desc_ptr, + void const* const smem_ptr, + int32_t const& crd0, int32_t const& crd1, int32_t const& crd2, int32_t const& crd3, int32_t const& crd4) + { + return SM90_TMA_REDUCE_ADD_5D::copy(desc_ptr, smem_ptr, crd0, crd1, crd2, crd3, crd4); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +/// BULK_COPY : Copy a bulk of memory between shared memory and global memory +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM90_BULK_COPY_G2S +{ + CUTE_HOST_DEVICE static void + copy(void const* gmem_ptr, uint64_t* mbar_ptr, + void * smem_ptr, int32_t load_bytes) + { +#if 
defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint32_t smem_int_mbar = cast_smem_ptr_to_uint(mbar_ptr); + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.shared::cluster.global.mbarrier::complete_tx::bytes [%0], [%1], %2, [%3];\n" + : + : "r"(smem_int_ptr), "l"(gmem_ptr), "r"(load_bytes), "r"(smem_int_mbar) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use BULK_COPY without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } + + struct PREFETCH + { + CUTE_HOST_DEVICE static void + copy(void const* gmem_ptr, int32_t load_bytes) + { + #if defined(CUTE_ARCH_TMA_SM90_ENABLED) + asm volatile("cp.async.bulk.prefetch.L2.global [%0], %1;\n" + : + : "l"(gmem_ptr), "r"(load_bytes) + : "memory"); + #else + CUTE_INVALID_CONTROL_PATH("Trying to use BULK_COPY without CUTE_ARCH_TMA_SM90_ENABLED."); + #endif + } + }; +}; + +struct SM90_BULK_COPY_S2G +{ + CUTE_HOST_DEVICE static void + copy(void const* smem_ptr, + void * gmem_ptr, int32_t store_bytes) + { +#if defined(CUTE_ARCH_TMA_SM90_ENABLED) + uint32_t smem_int_ptr = cast_smem_ptr_to_uint(smem_ptr); + asm volatile("cp.async.bulk.global.shared::cta.bulk_group [%0], [%1], %2;\n" + : + : "l"(gmem_ptr), "r"(smem_int_ptr), "r"(store_bytes) + : "memory"); +#else + CUTE_INVALID_CONTROL_PATH("Trying to use BULK_COPY without CUTE_ARCH_TMA_SM90_ENABLED."); +#endif + } +}; + +struct SM90_BULK_COPY_AUTO {}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/include/cute/arch/mma.hpp b/include/cute/arch/mma.hpp new file mode 100644 index 0000000000..6e06114a6c --- /dev/null +++ b/include/cute/arch/mma.hpp @@ -0,0 +1,64 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::fma +#include // cute::fma + +namespace cute +{ + +// +// Direct FMA for any type +// + +template +struct UniversalFMA +{ + using DRegisters = D[1]; + using ARegisters = A[1]; + using BRegisters = B[1]; + using CRegisters = C[1]; + + CUTE_HOST_DEVICE static constexpr void + fma(D & d, + A const& a, + B const& b, + C const& c) + { + // Forward to an ADL/cute free function for these types + using cute::fma; + fma(d, a, b, c); + } +}; + +} // end namespace cute diff --git a/include/cute/arch/mma_sm61.hpp b/include/cute/arch/mma_sm61.hpp new file mode 100644 index 0000000000..f7bcb7d19d --- /dev/null +++ b/include/cute/arch/mma_sm61.hpp @@ -0,0 +1,87 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include +#include + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 610)) +# define CUTE_ARCH_MMA_SM61_ENABLED +#endif + +namespace cute +{ + +struct SM61_DP4A +{ + using DRegisters = int32_t[1]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = int32_t[1]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(int32_t& d, uint32_t const& a, uint32_t const& b, int32_t const& c) + { +#if defined(CUTE_ARCH_MMA_SM61_ENABLED) + asm volatile("dp4a.s32.s32 %0, %1, %2, %3;" + : "=r"(d) + : "r"(a), "r"(b), "r"(c)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM61_DP4A without CUTE_ARCH_MMA_SM61_ENABLED"); +#endif + } +}; + +struct SM61_DP2A +{ + using DRegisters = int32_t[1]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = int32_t[1]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(int32_t& d, uint32_t const& a, uint32_t const& b, int32_t const& c) + { +#if defined(CUTE_ARCH_MMA_SM61_ENABLED) + asm volatile("dp2a.s32.s32 %0, %1, %2, %3;" + : "=r"(d) + : "r"(a), "r"(b), "r"(c)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM61_DP2A without CUTE_ARCH_MMA_SM61_ENABLED"); +#endif + } +}; + +} // namespace cute diff --git a/include/cute/arch/mma_sm70.hpp b/include/cute/arch/mma_sm70.hpp new file mode 100644 index 0000000000..63d96cf5d6 --- /dev/null +++ b/include/cute/arch/mma_sm70.hpp @@ -0,0 +1,329 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +#include + +// Config +#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 1)) +# define CUTE_ARCH_MMA_SM70_SUPPORTED +# if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) +# define CUTE_ARCH_MMA_SM70_ENABLED +# endif +#endif + +namespace cute +{ + +// +// SM70 MMA 884 F16F16F16 +// + +struct SM70_8x8x4_F16F16F16F16_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.row.col.f16.f16.f16.f16" + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6, %7}," + "{%8, %9, %10, %11};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F16F16F16F16_TN without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM70_8x8x4_F16F16F16F16_NT +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.col.row.f16.f16.f16.f16" + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6, %7}," + "{%8, %9, %10, %11};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F16F16F16F16_NT without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM70_8x8x4_F16F16F16F16_NN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.col.col.f16.f16.f16.f16" + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6, %7}," + "{%8, %9, %10, %11};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F16F16F16F16_NN without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM70_8x8x4_F16F16F16F16_TT +{ + using 
DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.row.row.f16.f16.f16.f16" + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6, %7}," + "{%8, %9, %10, %11};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F16F16F16F16_TT without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// SM70 MMA 884 F16F16F32 +// + +struct SM70_8x8x4_F32F16F16F32_TN +{ + using DRegisters = float[8]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = float[8]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3, + float const& c4, float const& c5, float const& c6, float const& c7) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11}," + "{%12, %13, %14, %15, %16, %17, %18, %19};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), + "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "f"(c4), "f"(c5), "f"(c6), "f"(c7)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F32F16F16F32_TN without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM70_8x8x4_F32F16F16F32_NT +{ + using DRegisters = float[8]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = float[8]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3, + float const& c4, float const& c5, float const& c6, float const& c7) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.col.row.f32.f16.f16.f32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11}," + "{%12, %13, %14, %15, %16, %17, %18, %19};" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), + "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "f"(c4), "f"(c5), "f"(c6), "f"(c7)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F32F16F16F32_NT without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM70_8x8x4_F32F16F16F32_NN +{ + using DRegisters = float[8]; + using ARegisters = uint32_t[2]; + 
using BRegisters = uint32_t[2]; + using CRegisters = float[8]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3, + float const& c4, float const& c5, float const& c6, float const& c7) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.col.col.f32.f16.f16.f32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11}," + "{%12, %13, %14, %15, %16, %17, %18, %19};" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), + "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "f"(c4), "f"(c5), "f"(c6), "f"(c7)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F32F16F16F32_NN without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +struct SM70_8x8x4_F32F16F16F32_TT +{ + using DRegisters = float[8]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[2]; + using CRegisters = float[8]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, uint32_t const& b1, + float const& c0, float const& c1, float const& c2, float const& c3, + float const& c4, float const& c5, float const& c6, float const& c7) + { +#if defined(CUTE_ARCH_MMA_SM70_ENABLED) + asm volatile("mma.sync.aligned.m8n8k4.row.row.f32.f16.f16.f32" + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11}," + "{%12, %13, %14, %15, %16, %17, %18, %19};" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3), + "=f"(d4), "=f"(d5), "=f"(d6), "=f"(d7) + : "r"(a0), "r"(a1), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3), + "f"(c4), "f"(c5), "f"(c6), "f"(c7)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM70_8x8x4_F32F16F16F32_TT without CUTE_ARCH_MMA_SM70_ENABLED"); +#endif + } + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/include/cute/arch/mma_sm75.hpp b/include/cute/arch/mma_sm75.hpp new file mode 100644 index 0000000000..c33f7b391c --- /dev/null +++ b/include/cute/arch/mma_sm75.hpp @@ -0,0 +1,120 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +// Config +#if ((__CUDACC_VER_MAJOR__ > 10) || (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2)) +# define CUTE_ARCH_MMA_SM75_SUPPORTED +# if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750)) +# define CUTE_ARCH_MMA_SM75_ENABLED +# endif +#endif + +namespace cute +{ + +// +// SM75 MMA 1688 F16F16F32 +// + +struct SM75_16x8x8_F32F16F16F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = float[4]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + float const& c0, float const& c1, float const& c2, float const& c3) + { +#if defined(CUTE_ARCH_MMA_SM75_ENABLED) + asm volatile("mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32" + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM75_16x8x8_F32F16F16F32_TN without CUTE_ARCH_MMA_SM75_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// +// SM75 MMA 8816 S8S8S32 +// + +struct SM75_8x8x16_S32S8S8S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + // Register asm fma + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM75_ENABLED) + asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32" + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM75_8x8x16_S32S8S8S32_TN without CUTE_ARCH_MMA_SM75_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/include/cute/arch/mma_sm80.hpp b/include/cute/arch/mma_sm80.hpp new file mode 100644 index 0000000000..17860dd40f --- /dev/null +++ b/include/cute/arch/mma_sm80.hpp @@ -0,0 +1,2241 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) +# define CUTE_ARCH_MMA_SM80_ENABLED + +#if (__CUDA_ARCH__ <= 900) +#define CUTE_ARCH_MMA_B1_AND_SM80_ENABLED +#endif + +#if (__CUDA_ARCH__ <= 890) +#define CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED +#endif + +#endif + + + +namespace cute { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x8 TN +struct SM80_16x8x8_F16F16F16F16_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3}," + "{%4}," + "{%5, %6};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x8_F16F16F16F16_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_F16F16F16F16_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f16.f16.f16.f16 " + "{%0, %1}," + "{%2, 
%3, %4, %5}," + "{%6, %7}," + "{%8, %9};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_F16F16F16F16_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x8 TN +struct SM80_16x8x8_F32F16F16F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x8_F32F16F16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_F32F16F16F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.f16.f16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_F32F16F16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x8 TN +struct SM80_16x8x8_F32BF16BF16F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x8_F32BF16BF16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_F32BF16BF16F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = 
float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f32.bf16.bf16.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_F32BF16BF16F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x4 TN +struct SM80_16x8x4_F32TF32TF32F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k4.row.col.f32.tf32.tf32.f32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x4_F32TF32TF32F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x8 TN +struct SM80_16x8x8_F32TF32TF32F32_TN +{ + using DRegisters = float[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(float & d0, float & d1, float & d2, float & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + float const & c0, float const & c1, float const & c2, float const & c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f32.tf32.tf32.f32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=f"(d0), "=f"(d1), "=f"(d2), "=f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "f"(c0), "f"(c1), "f"(c2), "f"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x8_F32TF32TF32F32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x4 TN +struct SM80_8x8x4_F64F64F64F64_TN +{ + using DRegisters = double[2]; + using ARegisters = double[1]; + using BRegisters = double[1]; + using CRegisters = double[2]; + + CUTE_HOST_DEVICE static void + fma(double & d0, double & d1, + double const& a0, + double const& b0, + double const& c0, double const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k4.row.col.f64.f64.f64.f64 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=d"(d0), "=d"(d1) + : "d"(a0), + "d"(b0), + "d"(c0), "d"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x4_F64F64F64F64_TN 
without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +// MMA 8x8x4 TN with Planar Complex multiplication +struct SM80_8x8x4_C64C64C64C64_TN +{ + using DRegisters = complex<double>[2]; + using ARegisters = complex<double>[1]; + using BRegisters = complex<double>[1]; + using CRegisters = complex<double>[2]; + + CUTE_HOST_DEVICE static void + fma(complex<double> & d0, complex<double> & d1, + complex<double> const& a0, + complex<double> const& b0, + complex<double> const& c0, complex<double> const& c1) + { + // Because thrust::complex does not provide a mutable ref + double& rd0 = reinterpret_cast<double(&)[2]>(d0)[0]; + double& id0 = reinterpret_cast<double(&)[2]>(d0)[1]; + double& rd1 = reinterpret_cast<double(&)[2]>(d1)[0]; + double& id1 = reinterpret_cast<double(&)[2]>(d1)[1]; + + // d.real() = a.real() * b.real() + c.real(); + SM80_8x8x4_F64F64F64F64_TN::fma( + rd0, rd1, + a0.real(), + b0.real(), + c0.real(), c1.real()); + + // d.imag() = a.imag() * b.real() + c.imag(); + SM80_8x8x4_F64F64F64F64_TN::fma( + id0, id1, + a0.imag(), + b0.real(), + c0.imag(), c1.imag()); + + // d.real() = -a.imag() * b.imag() + d.real(); + SM80_8x8x4_F64F64F64F64_TN::fma( + rd0, rd1, + -a0.imag(), + b0.imag(), + d0.real(), d1.real()); + + // d.imag() = a.real() * b.imag() + d.imag(); + SM80_8x8x4_F64F64F64F64_TN::fma( + id0, id1, + a0.real(), + b0.imag(), + d0.imag(), d1.imag()); + } +}; + +// MMA 8x8x4 TN with Gaussian Complex multiplication: +// (a + bi)*(c + di) +// yields +// t0 += a*c +// t1 += b*d +// t2 += (a+b)*(c+d) +// then +// re = t0 - t1 +// im = t2 - t0 - t1 +struct SM80_8x8x4_GC64C64C64GC64_TN +{ + struct GaussComplex { + double t0, t1, t2; + + CUTE_HOST_DEVICE //constexpr + operator complex<double>() const { return complex<double>(t0 - t1, t2 - t0 - t1); } + + CUTE_HOST_DEVICE friend //constexpr + complex<double> operator*(GaussComplex const& a, complex<double> const& b) { return static_cast<complex<double>>(a) * b; } + CUTE_HOST_DEVICE friend //constexpr + complex<double> operator*(complex<double> const& a, GaussComplex const& b) { return b * a; } + + CUTE_HOST_DEVICE friend //constexpr + complex<double> operator+(GaussComplex const& a, complex<double> const& b) { return static_cast<complex<double>>(a) + b; } + CUTE_HOST_DEVICE friend //constexpr + complex<double> operator+(complex<double> const& a, GaussComplex const& b) { return b + a; } + }; + + using DRegisters = GaussComplex[2]; + using ARegisters = complex<double>[1]; + using BRegisters = complex<double>[1]; + using CRegisters = GaussComplex[2]; + + CUTE_HOST_DEVICE static void + fma(GaussComplex & d0, GaussComplex & d1, + complex<double> const& a0, + complex<double> const& b0, + GaussComplex const& c0, GaussComplex const& c1) + { + SM80_8x8x4_F64F64F64F64_TN::fma(d0.t0, d1.t0, + a0.real(), + b0.real(), + c0.t0, c1.t0); + SM80_8x8x4_F64F64F64F64_TN::fma(d0.t1, d1.t1, + a0.imag(), + b0.imag(), + c0.t1, c1.t1); + SM80_8x8x4_F64F64F64F64_TN::fma(d0.t2, d1.t2, + a0.real() + a0.imag(), + b0.real() + b0.imag(), + c0.t2, c1.t2); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32S8S8S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32S8S8S32_TN without 
CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32S8S8S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32S8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32S8S8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32S8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32S8S8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32S8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S8S8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + 
"mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S8S8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32S8U8S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.u8.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32S8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32S8U8S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.s8.u8.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32S8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32S8U8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, 
uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32S8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32S8U8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32S8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S8U8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S8U8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), 
"=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32U8S8S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.u8.s8.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32U8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32U8S8S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.u8.s8.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32U8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32U8S8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32U8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32U8S8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + 
"mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32U8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U8S8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U8S8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U8S8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U8S8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct SM80_8x8x16_S32U8U8S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32U8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x16 TN +struct 
SM80_8x8x16_S32U8U8S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x16_S32U8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32U8U8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32U8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct SM80_16x8x16_S32U8U8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x16_S32U8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U8U8S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : 
"r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U8U8S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U8U8S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U8U8S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32S4S4S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32S4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32S4S4S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32S4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S4S4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if 
defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S4S4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s4.s4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32S4S4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32S4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32S4S4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM80_16x8x64_S32S4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32S4U4S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.s4.u4.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32S4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32S4U4S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.s4.u4.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32S4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S4U4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s4.u4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32S4U4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s4.u4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), 
+ "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32S4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32S4U4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32S4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32S4U4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32S4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32U4S4S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.u4.s4.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32U4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32U4S4S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, 
uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.u4.s4.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32U4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U4S4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U4S4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u4.s4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32U4S4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32U4S4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + 
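Each of these SM80 atoms is a thin wrapper around a single `mma.sync` PTX instruction: the `DRegisters`/`ARegisters`/`BRegisters`/`CRegisters` aliases record the per-thread register footprint of each operand, and `fma` issues the instruction. As a rough sketch of direct use (the helper name and register arrays below are illustrative, not part of CUTLASS), a warp whose threads already hold their fragments in the layout required by the PTX `m16n8k64` shape could accumulate in place as follows; in practice the op is normally wrapped in a `cute::MMA_Atom`/`TiledMMA` so that CuTe derives those fragment layouts for you.

#include <cute/arch/mma_sm80.hpp>

// Illustrative sketch only: issue one u4 x s4 -> s32 m16n8k64 MMA, accumulating in place.
// Every thread of the warp supplies its own packed fragment registers; the
// thread-to-fragment mapping is the one fixed by the PTX mma instruction.
__device__ void warp_mma_u4s4(uint32_t const (&a)[4],   // packed unsigned 4-bit A fragment
                              uint32_t const (&b)[2],   // packed signed 4-bit B fragment
                              uint32_t       (&acc)[4]) // s32 accumulators (held in uint32_t registers, as the atom expects)
{
  cute::SM80_16x8x64_S32U4S4S32_TN::fma(
      acc[0], acc[1], acc[2], acc[3],   // D
      a[0],   a[1],   a[2],   a[3],     // A
      b[0],   b[1],                     // B
      acc[0], acc[1], acc[2], acc[3]);  // C: reuse the D registers to accumulate
}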
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32U4S4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32U4S4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32U4U4S32_TN +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32 " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32U4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x32 TN +struct SM80_8x8x32_S32U4U4S32_TN_SATURATE +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32.satfinite " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x32_S32U4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U4U4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u4.u4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), 
"r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x32 TN +struct SM80_16x8x32_S32U4U4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.u4.u4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x32_S32U4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32U4U4S32_TN +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32U4U4S32_TN without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x64 TN +struct SM80_16x8x64_S32U4U4S32_TN_SATURATE +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x64_S32U4U4S32_TN_SATURATE without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x128 TN +struct 
SM80_8x8x128_S32U1U1S32_TN_ANDPOPC +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.and.popc " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x128_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x128 TN +struct SM80_16x8x128_S32U1U1S32_TN_ANDPOPC +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.and.popc " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x128_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x256 TN +struct SM80_16x8x256_S32U1U1S32_TN_ANDPOPC +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_B1_AND_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x256_S32U1U1S32_TN_ANDPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 8x8x128 TN +struct SM80_8x8x128_S32U1U1S32_TN_XORPOPC +{ + using DRegisters = uint32_t[2]; + using ARegisters = uint32_t[1]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, + uint32_t const& a0, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1) + { +#if defined(CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m8n8k128.row.col.s32.b1.b1.s32.xor.popc " + "{%0, %1}," + "{%2}," + "{%3}," + "{%4, %5};\n" + : "=r"(d0), "=r"(d1) + : "r"(a0), + "r"(b0), + "r"(c0), "r"(c1)); +#else 
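For the single-bit (`b1`) atoms, the "multiply" is a bitwise operation and the reduction is a population count: the `.and.popc` forms accumulate `popc(a & b)` and the `.xor.popc` forms accumulate `popc(a ^ b)` into `s32` accumulators, with 32 one-bit lanes packed per `uint32_t`. A scalar reference of the per-element semantics, as a sketch only (names are illustrative):

    #include <bit>
    #include <cstdint>

    // One 32-lane slice of the K reduction for a single accumulator element.
    int32_t b1_popc_step(uint32_t a, uint32_t b, int32_t c, bool use_xor)
    {
      uint32_t prod = use_xor ? (a ^ b) : (a & b);  // .xor.popc vs .and.popc
      return c + std::popcount(prod);
    }

Note also that some of these variants are guarded by dedicated macros (`CUTE_ARCH_MMA_B1_AND_SM80_ENABLED`, `CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED`) rather than the generic `CUTE_ARCH_MMA_SM80_ENABLED`.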
+ CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_8x8x128_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x128 TN +struct SM80_16x8x128_S32U1U1S32_TN_XORPOPC +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[2]; + using BRegisters = uint32_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, + uint32_t const& b0, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k128.row.col.s32.b1.b1.s32.xor.popc " + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), + "r"(b0), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x128_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x256 TN +struct SM80_16x8x256_S32U1U1S32_TN_XORPOPC +{ + using DRegisters = uint32_t[4]; + using ARegisters = uint32_t[4]; + using BRegisters = uint32_t[2]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint32_t const& b0, uint32_t const& b1, + uint32_t const& c0, uint32_t const& c1, uint32_t const& c2, uint32_t const& c3) + { +#if defined(CUTE_ARCH_MMA_B1_XOR_SM80_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=r"(d0), "=r"(d1), "=r"(d2), "=r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "r"(b0), "r"(b1), + "r"(c0), "r"(c1), "r"(c2), "r"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM80_16x8x256_S32U1U1S32_TN_XORPOPC without CUTE_ARCH_MMA_SM80_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cute diff --git a/include/cute/arch/mma_sm90.hpp b/include/cute/arch/mma_sm90.hpp new file mode 100644 index 0000000000..51d34563c4 --- /dev/null +++ b/include/cute/arch/mma_sm90.hpp @@ -0,0 +1,9331 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +// Config +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) +# define CUTE_ARCH_MMA_SM90_ENABLED +# define CUTE_ARCH_MMA_F64_SM90_ENABLED +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cute { + +namespace SM90 { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x4 TN +struct MMA_16x8x4_F64F64F64F64_TN +{ + using DRegisters = double[4]; + using ARegisters = double[2]; + using BRegisters = double[1]; + using CRegisters = double[4]; + + CUTE_HOST_DEVICE static void + fma(double & d0, double & d1, double & d2, double & d3, + double const& a0, double const& a1, + double const& b0, + double const& c0, double const& c1, double const& c2, double const& c3) + { +#if defined(CUTE_ARCH_MMA_F64_SM90_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64" + "{%0, %1, %2, %3}," + "{%4, %5}," + "{%6}," + "{%7, %8, %9, %10};\n" + : "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3) + : "d"(a0), "d"(a1), + "d"(b0), + "d"(c0), "d"(c1), "d"(c2), "d"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_16x8x4_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x8 TN +struct MMA_16x8x8_F64F64F64F64_TN +{ + using DRegisters = double[4]; + using ARegisters = double[4]; + using BRegisters = double[2]; + using CRegisters = double[4]; + + CUTE_HOST_DEVICE static void + fma(double & d0, double & d1, double & d2, double & d3, + double const& a0, double const& a1, double const& a2, double const& a3, + double const& b0, double const& b1, + double const& c0, double const& c1, double const& c2, double const& c3) + { +#if defined(CUTE_ARCH_MMA_F64_SM90_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k8.row.col.f64.f64.f64.f64" + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + "{%8, %9}," + "{%10, %11, %12, %13};\n" + : "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3) + : "d"(a0), "d"(a1), "d"(a2), "d"(a3), + "d"(b0), "d"(b1), + "d"(c0), "d"(c1), "d"(c2), "d"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_16x8x8_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct MMA_16x8x16_F64F64F64F64_TN +{ + using DRegisters = double[4]; + using ARegisters = double[8]; + using BRegisters = double[4]; + using CRegisters = double[4]; + + CUTE_HOST_DEVICE 
static void + fma(double & d0, double & d1, double & d2, double & d3, + double const& a0, double const& a1, double const& a2, double const& a3, + double const& a4, double const& a5, double const& a6, double const& a7, + double const& b0, double const& b1, double const& b2, double const& b3, + double const& c0, double const& c1, double const& c2, double const& c3) + { +#if defined(CUTE_ARCH_MMA_F64_SM90_ENABLED) + asm volatile( + "mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64" + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7, %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + "{%16, %17, %18, %19};\n" + : "=d"(d0), "=d"(d1), "=d"(d2), "=d"(d3) + : "d"(a0), "d"(a1), "d"(a2), "d"(a3), + "d"(a4), "d"(a5), "d"(a6), "d"(a7), + "d"(b0), "d"(b1), "d"(b2), "d"(b3), + "d"(c0), "d"(c1), "d"(c2), "d"(c3)); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_16x8x16_F64F64F64F64_TN without CUTE_ARCH_MMA_SM90_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x4 TN +struct MMA_16x8x4_C64C64C64C64_TN +{ + using DRegisters = complex[4]; + using ARegisters = complex[2]; + using BRegisters = complex[1]; + using CRegisters = complex[4]; + + CUTE_HOST_DEVICE static void + fma(complex & d0, complex & d1, + complex & d2, complex & d3, + complex const& a0, complex const& a1, + complex const& b0, + complex const& c0, complex const& c1, + complex const& c2, complex const& c3) + { + // Because thrust::complex does not provide a mutable ref + double& rd0 = reinterpret_cast(d0)[0]; + double& id0 = reinterpret_cast(d0)[1]; + double& rd1 = reinterpret_cast(d1)[0]; + double& id1 = reinterpret_cast(d1)[1]; + double& rd2 = reinterpret_cast(d2)[0]; + double& id2 = reinterpret_cast(d2)[1]; + double& rd3 = reinterpret_cast(d3)[0]; + double& id3 = reinterpret_cast(d3)[1]; + + // d.real() = a.real() * b.real() + c.real(); + MMA_16x8x4_F64F64F64F64_TN::fma( + rd0, rd1, rd2, rd3, + a0.real(), a1.real(), + b0.real(), + c0.real(), c1.real(), c2.real(), c3.real()); + + // d.imag() = a.imag() * b.real() + c.imag(); + MMA_16x8x4_F64F64F64F64_TN::fma( + id0, id1, id2, id3, + a0.imag(), a1.imag(), + b0.real(), + c0.imag(), c1.imag(), c2.imag(), c3.imag()); + + // d.real() = -a.imag() * b.imag() + d.real(); + MMA_16x8x4_F64F64F64F64_TN::fma( + rd0, rd1, rd2, rd3, + -a0.imag(), -a1.imag(), + b0.imag(), + d0.real(), d1.real(), d2.real(), d3.real()); + + // d.imag() = a.real() * b.imag() + d.imag(); + MMA_16x8x4_F64F64F64F64_TN::fma( + id0, id1, id2, id3, + a0.real(), a1.real(), + b0.imag(), + d0.imag(), d1.imag(), d2.imag(), d3.imag()); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x8 TN +struct MMA_16x8x8_C64C64C64C64_TN +{ + using DRegisters = complex[4]; + using ARegisters = complex[4]; + using BRegisters = complex[2]; + using CRegisters = complex[4]; + + CUTE_HOST_DEVICE static void + fma(complex & d0, complex & d1, + complex & d2, complex & d3, + complex const& a0, complex const& a1, + complex const& a2, complex const& a3, + complex const& b0, complex const& b1, + complex const& c0, complex const& c1, + complex const& c2, complex const& c3) + { + // Because thrust::complex does not provide a mutable ref + double& rd0 = reinterpret_cast(d0)[0]; + double& id0 = reinterpret_cast(d0)[1]; + double& rd1 = reinterpret_cast(d1)[0]; + double& id1 = reinterpret_cast(d1)[1]; + double& rd2 = reinterpret_cast(d2)[0]; + double& id2 = reinterpret_cast(d2)[1]; + double& rd3 = 
reinterpret_cast(d3)[0]; + double& id3 = reinterpret_cast(d3)[1]; + + // d.real() = a.real() * b.real() + c.real(); + MMA_16x8x8_F64F64F64F64_TN::fma( + rd0, rd1, rd2, rd3, + a0.real(), a1.real(), a2.real(), a3.real(), + b0.real(), b1.real(), + c0.real(), c1.real(), c2.real(), c3.real()); + + // d.imag() = a.imag() * b.real() + c.imag(); + MMA_16x8x8_F64F64F64F64_TN::fma( + id0, id1, id2, id3, + a0.imag(), a1.imag(), a2.imag(), a3.imag(), + b0.real(), b1.real(), + c0.imag(), c1.imag(), c2.imag(), c3.imag()); + + // d.real() = -a.imag() * b.imag() + d.real(); + MMA_16x8x8_F64F64F64F64_TN::fma( + rd0, rd1, rd2, rd3, + -a0.imag(), -a1.imag(), -a2.imag(), -a3.imag(), + b0.imag(), b1.imag(), + d0.real(), d1.real(), d2.real(), d3.real()); + + // d.imag() = a.real() * b.imag() + d.imag(); + MMA_16x8x8_F64F64F64F64_TN::fma( + id0, id1, id2, id3, + a0.real(), a1.real(), a2.real(), a3.real(), + b0.imag(), b1.imag(), + d0.imag(), d1.imag(), d2.imag(), d3.imag()); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// MMA 16x8x16 TN +struct MMA_16x8x16_C64C64C64C64_TN +{ + using DRegisters = complex[4]; + using ARegisters = complex[8]; + using BRegisters = complex[4]; + using CRegisters = complex[4]; + + CUTE_HOST_DEVICE static void + fma(complex & d0, complex & d1, + complex & d2, complex & d3, + complex const& a0, complex const& a1, + complex const& a2, complex const& a3, + complex const& a4, complex const& a5, + complex const& a6, complex const& a7, + complex const& b0, complex const& b1, + complex const& b2, complex const& b3, + complex const& c0, complex const& c1, + complex const& c2, complex const& c3) + { + // Because thrust::complex does not provide a mutable ref + double& rd0 = reinterpret_cast(d0)[0]; + double& id0 = reinterpret_cast(d0)[1]; + double& rd1 = reinterpret_cast(d1)[0]; + double& id1 = reinterpret_cast(d1)[1]; + double& rd2 = reinterpret_cast(d2)[0]; + double& id2 = reinterpret_cast(d2)[1]; + double& rd3 = reinterpret_cast(d3)[0]; + double& id3 = reinterpret_cast(d3)[1]; + + // d.real() = a.real() * b.real() + c.real(); + MMA_16x8x16_F64F64F64F64_TN::fma( + rd0, rd1, rd2, rd3, + a0.real(), a1.real(), a2.real(), a3.real(), + a4.real(), a5.real(), a6.real(), a7.real(), + b0.real(), b1.real(), b2.real(), b3.real(), + c0.real(), c1.real(), c2.real(), c3.real()); + + // d.imag() = a.imag() * b.real() + c.imag(); + MMA_16x8x16_F64F64F64F64_TN::fma( + id0, id1, id2, id3, + a0.imag(), a1.imag(), a2.imag(), a3.imag(), + a4.imag(), a5.imag(), a6.imag(), a7.imag(), + b0.real(), b1.real(), b2.real(), b3.real(), + c0.imag(), c1.imag(), c2.imag(), c3.imag()); + + // d.real() = -a.imag() * b.imag() + d.real(); + MMA_16x8x16_F64F64F64F64_TN::fma( + rd0, rd1, rd2, rd3, + -a0.imag(), -a1.imag(), -a2.imag(), -a3.imag(), + -a4.imag(), -a5.imag(), -a6.imag(), -a7.imag(), + b0.imag(), b1.imag(), b2.imag(), b3.imag(), + d0.real(), d1.real(), d2.real(), d3.real()); + + // d.imag() = a.real() * b.imag() + d.imag(); + MMA_16x8x16_F64F64F64F64_TN::fma( + id0, id1, id2, id3, + a0.real(), a1.real(), a2.real(), a3.real(), + a4.real(), a5.real(), a6.real(), a7.real(), + b0.imag(), b1.imag(), b2.imag(), b3.imag(), + d0.imag(), d1.imag(), d2.imag(), d3.imag()); + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} + +} // namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include // 
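The complex-valued wrappers above contain no PTX of their own: each one expands `d = a*b + c` over complex doubles into four calls to the corresponding real `f64` MMA, so that every step stays in fused multiply-add form. Writing `a = ar + i*ai` (and likewise for b and c), the parts are built up as `dr = ar*br + cr`, `di = ai*br + ci`, then `dr = -ai*bi + dr` and `di = ar*bi + di`, which matches the four `fma` calls and explains the negated `a.imag()` operands in the third call. A scalar reference of the same decomposition, for one accumulator element (illustrative only):

    #include <complex>

    std::complex<double> complex_fma_reference(std::complex<double> a,
                                               std::complex<double> b,
                                               std::complex<double> c)
    {
      double dr = a.real() * b.real() + c.real();  // real MMA #1
      double di = a.imag() * b.real() + c.imag();  // real MMA #2
      dr = -a.imag() * b.imag() + dr;              // real MMA #3
      di =  a.real() * b.imag() + di;              // real MMA #4
      return {dr, di};                             // equals a*b + c
    }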
cute::size +#include // cute::is_static +#include // cute::half_t, cute::float_e4m3_t, cute::tfloat32_t, etc +#include // cute::is_same_v + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cute { +namespace SM90::GMMA { + +template < + class ElementA, + class ElementB, + class ElementC, + class TileShape_MNK, + GMMA::Major MajorA = GMMA::Major::K, + GMMA::Major MajorB = GMMA::Major::K, + auto... Args // e.g. GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One] + // But most commonly leave empty for defaults +> +CUTE_HOST_DEVICE constexpr +auto +ss_op_selector() +{ + static_assert(is_static::value, "TileShape_MNK must be static."); + static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3."); + static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64."); + auto Tile_N = size<1>(TileShape_MNK{}); + + // F16 accumulator + if constexpr (is_same_v) { + + // Input A: half_t ; Input B: half_t + if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x16_F16F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x16_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x16_F16F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x16_F16F16F16_SS{}; + } +#endif + else if 
constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x16_F16F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x16_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x16_F16F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x16_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x16_F16F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x16_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x16_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x16_F16F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x16_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x16_F16F16F16_SS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x16_F16F16F16_SS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if 
constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + 
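Every branch of `ss_op_selector` follows the same dispatch pattern: walk the candidate instruction N sizes from largest to smallest and return the first GMMA op whose N divides `Tile_N`, so the tile's N extent is covered with as few instruction issues as possible. The shapes wrapped in `CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED` (248, 240, ..., 24) are only considered when that macro is defined; without it, a Tile_N that is not a multiple of one of {256, 192, 128, 96, 64, 32, 16, 8} falls through to a narrower op (for example, Tile_N = 144 resolves to the 64x16 shape, issued nine times across N). A small sketch of resolving an op at compile time, assuming the CuTe headers (e.g. `<cute/atom/mma_atom.hpp>`) are on the include path:

    // fp16 inputs, fp16 accumulation, 64x192x64 tile:
    // 192 is in the always-available set, so this resolves to a 64x192x16 SS op.
    using TileShape_MNK = cute::Shape<cute::_64, cute::_192, cute::_64>;
    using Op = decltype(cute::SM90::GMMA::ss_op_selector<
                          cute::half_t, cute::half_t, cute::half_t,
                          TileShape_MNK>());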
return SM90::GMMA::MMA_64x16x32_F16E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E4M3E4M3_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E4M3E5M2_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return 
SM90::GMMA::MMA_64x192x32_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E5M2E4M3_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr 
(is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + 
return SM90::GMMA::MMA_64x88x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E5M2E5M2_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // F32 accumulator + else if constexpr (is_same_v) { + + // Input A: half_t ; Input B: half_t + if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x16_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x16_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x16_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + 
return SM90::GMMA::MMA_64x168x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x16_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x16_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x16_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x16_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x16_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x16_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x16_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x16_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x16_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x16_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x16_F32F16F16_SS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x16_F32F16F16_SS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: bfloat16_t ; Input B: bfloat16_t + else if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x16_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + 
else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x16_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x16_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x16_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x16_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x16_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x16_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x16_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x16_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x16_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x16_F32BF16BF16_SS{}; + } +#endif +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x16_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x16_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x16_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x16_F32BF16BF16_SS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x16_F32BF16BF16_SS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: tfloat32_t ; Input B: tfloat32_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x8_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x8_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x8_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x8_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + 
return SM90::GMMA::MMA_64x128x8_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x8_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x8_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x8_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x8_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x8_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x8_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x8_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x8_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x8_F32TF32TF32_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x8_F32TF32TF32_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if 
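Note the constraints that differ between input types: the `half_t`/`bfloat16_t` paths accept any `MajorA`/`MajorB` (the defaults are `GMMA::Major::K`), whereas the `tfloat32_t` and FP8 paths `static_assert` that both operands are K-major, i.e. only the TN layout is supported, and they also require a different Tile_K multiple (8 for TF32 and 32 for FP8, versus 16 for FP16/BF16). For example (sketch only, same header assumptions as above):

    // tf32 inputs with fp32 accumulation, K-major only; a 64x128x32 tile
    // resolves to the 64x128x8 TN op.
    using TileShape_MNK = cute::Shape<cute::_64, cute::_128, cute::_32>;
    using Op = decltype(cute::SM90::GMMA::ss_op_selector<
                          cute::tfloat32_t, cute::tfloat32_t, float,
                          TileShape_MNK>());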
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E4M3E4M3_SS_TN{}; + } 
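Each type-pair branch above resolves the same way at compile time: the `else if constexpr` ladder walks candidate GMMA shapes from N = 256 downward and returns the first atom whose N-extent divides `Tile_N`; shapes outside the always-available set are wrapped in `#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)` and only participate when that macro is defined. The standalone sketch below (hypothetical helpers `is_base_shape` and `pick_gmma_n`, not part of this patch) mirrors that policy for the floating-point ladders, which step through every multiple of 8; the int8 ladders later in the file expose a sparser set of extended shapes.

```cpp
#include <cstdio>

// Candidate N-extents that the ladders above offer unconditionally.
constexpr bool is_base_shape(int n) {
  return n == 256 || n == 192 || n == 128 || n == 96 ||
         n == 64  || n == 32  || n == 16  || n == 8;
}

// First (largest) candidate that divides Tile_N wins, exactly as the
// else-if-constexpr chain resolves at compile time.
constexpr int pick_gmma_n(int tile_n, bool extended_shapes) {
  for (int n = 256; n >= 8; n -= 8) {
    if (!is_base_shape(n) && !extended_shapes) {
      continue;  // extended shapes need CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED
    }
    if (tile_n % n == 0) {
      return n;
    }
  }
  return 0;  // Tile_N is not a multiple of 8 -> the selector static_asserts
}

int main() {
  static_assert(pick_gmma_n(192, false) == 192, "base shape, always available");
  static_assert(pick_gmma_n(144, false) ==  16, "without extended shapes, falls back to 16");
  static_assert(pick_gmma_n(144, true)  == 144, "with extended shapes, 144 is picked");
  std::printf("Tile_N=208 -> N=%d\n", pick_gmma_n(208, true));
  return 0;
}
```

Compiled with any C++17 host compiler, the `static_assert`s document the fallback behaviour when the extended shapes are compiled out.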
+#endif
+      else if constexpr (Tile_N % 16 == 0) {
+        return SM90::GMMA::MMA_64x16x32_F32E4M3E4M3_SS_TN{};
+      }
+      else if constexpr (Tile_N % 8 == 0) {
+        return SM90::GMMA::MMA_64x8x32_F32E4M3E4M3_SS_TN{};
+      }
+      else {
+        static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
+      }
+    }
+
+    // Input A: float_e4m3_t ; Input B: float_e5m2_t
+    else if constexpr (is_same_v<ElementA, float_e4m3_t> && is_same_v<ElementB, float_e5m2_t>) {
+      static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config.");
+      static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
+      static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
+
+      if constexpr (Tile_N % 256 == 0) {
+        return SM90::GMMA::MMA_64x256x32_F32E4M3E5M2_SS_TN{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 248 == 0) {
+        return SM90::GMMA::MMA_64x248x32_F32E4M3E5M2_SS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 240 == 0) {
+        return SM90::GMMA::MMA_64x240x32_F32E4M3E5M2_SS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 232 == 0) {
+        return SM90::GMMA::MMA_64x232x32_F32E4M3E5M2_SS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 224 == 0) {
+        return SM90::GMMA::MMA_64x224x32_F32E4M3E5M2_SS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 216 == 0) {
+        return SM90::GMMA::MMA_64x216x32_F32E4M3E5M2_SS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 208 == 0) {
+        return SM90::GMMA::MMA_64x208x32_F32E4M3E5M2_SS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 200 == 0) {
+        return SM90::GMMA::MMA_64x200x32_F32E4M3E5M2_SS_TN{};
+      }
+#endif
+      else if constexpr (Tile_N % 192 == 0) {
+        return SM90::GMMA::MMA_64x192x32_F32E4M3E5M2_SS_TN{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 184 == 0) {
+        return SM90::GMMA::MMA_64x184x32_F32E4M3E5M2_SS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 176 == 0) {
+        return SM90::GMMA::MMA_64x176x32_F32E4M3E5M2_SS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 168 == 0) {
+        return SM90::GMMA::MMA_64x168x32_F32E4M3E5M2_SS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 160 == 0) {
+        return SM90::GMMA::MMA_64x160x32_F32E4M3E5M2_SS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 152 == 0) {
+        return SM90::GMMA::MMA_64x152x32_F32E4M3E5M2_SS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 144 == 0) {
+        return SM90::GMMA::MMA_64x144x32_F32E4M3E5M2_SS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 136 == 0) {
+        return SM90::GMMA::MMA_64x136x32_F32E4M3E5M2_SS_TN{};
+      }
+#endif
+      else if constexpr (Tile_N % 128 == 0) {
+        return SM90::GMMA::MMA_64x128x32_F32E4M3E5M2_SS_TN{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 120 == 0) {
+        return SM90::GMMA::MMA_64x120x32_F32E4M3E5M2_SS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 112 == 0) {
+        return SM90::GMMA::MMA_64x112x32_F32E4M3E5M2_SS_TN{};
+      }
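For orientation, this is how a dense selector of this kind is typically consumed by a mainloop or collective builder. The sketch assumes the upstream CUTLASS spelling `cute::GMMA::ss_op_selector` and the usual CuTe headers; the namespaces and header layout in this patch may differ, so treat it as illustrative rather than authoritative.

```cpp
#include <cstdio>
#include <cute/tensor.hpp>
#include <cute/atom/mma_atom.hpp>

using namespace cute;

// 128x128x64 tile: Tile_M is a multiple of 64 and Tile_K a multiple of 16,
// as the half_t x half_t -> float branch of the selector requires.
using TileShape_MNK = Shape<_128, _128, _64>;

// The selector returns the widest SS GMMA atom whose N-extent divides 128;
// make_tiled_mma wraps it into a warpgroup-wide TiledMMA.
using MmaAtom  = decltype(GMMA::ss_op_selector<half_t, half_t, float, TileShape_MNK>());
using TiledMma = decltype(make_tiled_mma(MmaAtom{}));

int main() {
  // A single Hopper warpgroup drives the GMMA, so the TiledMMA spans 128 threads.
  std::printf("TiledMMA threads: %d\n", int(size(TiledMma{})));
  return 0;
}
```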
+#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F32E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F32E4M3E5M2_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return 
SM90::GMMA::MMA_64x192x32_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F32E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F32E5M2E4M3_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr 
(is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + 
return SM90::GMMA::MMA_64x88x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F32E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F32E5M2E5M2_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // S32 accumulator + else if constexpr (is_same_v) { + + // Input A: int8_t ; Input B: int8_t + if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32S8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_S32S8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32S8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32S8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_S32S8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_S32S8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_S32S8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_S32S8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_S32S8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_S32S8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_S32S8S8_SS_TN{}; + } +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_S32S8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_S32S8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_S32S8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32S8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32S8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32S8S8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32S8S8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: int8_t ; Input B: uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_S32S8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32S8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_S32S8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_S32S8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32S8U8_SS_TN{}; + } + else if 
constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32S8U8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: int8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_S32U8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32U8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_S32U8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_S32U8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32U8S8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32U8S8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return 
SM90::GMMA::MMA_64x240x32_S32U8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32U8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_S32U8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_S32U8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32U8U8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32U8U8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // Unknown accumulator type + else { + static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type."); + } +} + +template < + class ElementA, + class ElementB, + class ElementC, + class TileShape_MNK, + GMMA::Major MajorA = GMMA::Major::K, + GMMA::Major MajorB = GMMA::Major::K, + auto... Args // e.g. 
GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One]
+                 // But most commonly leave empty for defaults
+>
+CUTE_HOST_DEVICE constexpr
+auto
+ss_op_selector_sparse()
+{
+  static_assert(is_static<TileShape_MNK>::value, "TileShape_MNK must be static.");
+  static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3.");
+  static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64.");
+  auto Tile_N = size<1>(TileShape_MNK{});
+
+  // F16 accumulator
+  if constexpr (is_same_v<ElementC, half_t>) {
+
+    // Input A: half_t ; Input B: half_t
+    if constexpr (is_same_v<ElementA, half_t> && is_same_v<ElementB, half_t>) {
+      static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32.");
+
+      if constexpr (Tile_N % 256 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x256x32_F16F16F16_SS{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 248 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x248x32_F16F16F16_SS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 240 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x240x32_F16F16F16_SS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 232 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x232x32_F16F16F16_SS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 224 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x224x32_F16F16F16_SS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 216 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x216x32_F16F16F16_SS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 208 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x208x32_F16F16F16_SS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 200 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x200x32_F16F16F16_SS{};
+      }
+#endif
+      else if constexpr (Tile_N % 192 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x192x32_F16F16F16_SS{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 184 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x184x32_F16F16F16_SS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 176 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x176x32_F16F16F16_SS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 168 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x168x32_F16F16F16_SS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 160 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x160x32_F16F16F16_SS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 152 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x152x32_F16F16F16_SS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 144 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x144x32_F16F16F16_SS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 136 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x136x32_F16F16F16_SS{};
+      }
+#endif
+      else if constexpr (Tile_N % 128 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x128x32_F16F16F16_SS{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 120 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x120x32_F16F16F16_SS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr
(Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x32_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x32_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x32_F16F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x32_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x32_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x32_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x32_F16F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x32_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x32_F16F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x32_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x32_F16F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x32_F16F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x32_F16F16F16_SS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x32_F16F16F16_SS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E4M3_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 
== 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E5M2_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E4M3_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 
0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E5M2_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // F32 accumulator + else if constexpr (is_same_v) { + + // Input A: half_t ; Input B: half_t + if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x32_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) 
{ + return SM90::GMMA::SPARSE::GMMA_64x224x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x32_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x32_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x32_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x32_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x32_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x32_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x32_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x32_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x32_F32F16F16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 
40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x32_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x32_F32F16F16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x32_F32F16F16_SS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x32_F32F16F16_SS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x32_F32F16F16_SS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: bfloat16_t ; Input B: bfloat16_t + else if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x32_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x32_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x32_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x32_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x32_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + 
else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x32_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x32_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x32_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x32_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x32_F32BF16BF16_SS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x32_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x32_F32BF16BF16_SS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x32_F32BF16BF16_SS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x32_F32BF16BF16_SS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x32_F32BF16BF16_SS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: tfloat32_t ; Input B: tfloat32_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x16_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x16_F32TF32TF32_SS_TN{}; + } 
+#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x16_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x16_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x16_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x16_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x16_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x16_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x16_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x16_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x16_F32TF32TF32_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x16_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { 
+ return SM90::GMMA::SPARSE::GMMA_64x32x16_F32TF32TF32_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x16_F32TF32TF32_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x16_F32TF32TF32_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x16_F32TF32TF32_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N 
% 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E4M3_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E5M2_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E5M2_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) 
{ + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E4M3_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E4M3_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E4M3_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E4M3_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E4M3_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E5M2_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E5M2_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E5M2_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E5M2_SS_TN{}; + } +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 56 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E5M2_SS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 48 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E5M2_SS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 40 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E5M2_SS_TN{};
+      }
+#endif
+      else if constexpr (Tile_N % 32 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E5M2_SS_TN{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 24 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E5M2_SS_TN{};
+      }
+#endif
+      else if constexpr (Tile_N % 16 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E5M2_SS_TN{};
+      }
+      else if constexpr (Tile_N % 8 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E5M2_SS_TN{};
+      }
+      else {
+        static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8.");
+      }
+    }
+
+    else {
+      static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration.");
+    }
+  }
+
+  // S32 accumulator
+  else if constexpr (is_same_v<ElementC, int32_t>) {
+
+    // Input A: int8_t ; Input B: int8_t
+    if constexpr (is_same_v<ElementA, int8_t> && is_same_v<ElementB, int8_t>) {
+      static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config.");
+      static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config.");
+      static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64.");
+
+      if constexpr (Tile_N % 256 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_SS_TN{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 240 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_SS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 224 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_SS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 208 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_SS_TN{};
+      }
+#endif
+      else if constexpr (Tile_N % 192 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_SS_TN{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 176 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_SS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 160 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_SS_TN{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 144 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_SS_TN{};
+      }
+#endif
+      else if constexpr (Tile_N % 128 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_SS_TN{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 112 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_SS_TN{};
+      }
+#endif
+      else if constexpr (Tile_N % 96 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_SS_TN{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 80 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_SS_TN{};
+      }
+#endif
+      else if constexpr (Tile_N % 64 == 0) {
+        return SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_SS_TN{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N
% 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: int8_t ; Input B: uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 
== 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: int8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8S8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N 
% 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_SS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_SS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8U8_SS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_SS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_SS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // Unknown accumulator type + else { + static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type."); + } +} + +template < + class ElementA, + class ElementB, + class ElementC, + class TileShape_MNK, + GMMA::Major MajorA = GMMA::Major::K, + GMMA::Major MajorB = GMMA::Major::K, + auto... Args // e.g. 
GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One]
+  // But most commonly leave empty for defaults
+>
+CUTE_HOST_DEVICE constexpr
+auto
+rs_op_selector()
+{
+  static_assert(is_static<TileShape_MNK>::value, "TileShape_MNK must be static.");
+  static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3.");
+  static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64.");
+  static_assert(MajorA == GMMA::Major::K, "Register source A operand GMMAs must have K-major A layout.");
+  auto Tile_N = size<1>(TileShape_MNK{});
+
+  // F16 accumulator
+  if constexpr (is_same_v<ElementC, half_t>) {
+
+    // Input A: half_t ; Input B: half_t
+    if constexpr (is_same_v<ElementA, half_t> && is_same_v<ElementB, half_t>) {
+      static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16.");
+
+      if constexpr (Tile_N % 256 == 0) {
+        return SM90::GMMA::MMA_64x256x16_F16F16F16_RS{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 248 == 0) {
+        return SM90::GMMA::MMA_64x248x16_F16F16F16_RS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 240 == 0) {
+        return SM90::GMMA::MMA_64x240x16_F16F16F16_RS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 232 == 0) {
+        return SM90::GMMA::MMA_64x232x16_F16F16F16_RS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 224 == 0) {
+        return SM90::GMMA::MMA_64x224x16_F16F16F16_RS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 216 == 0) {
+        return SM90::GMMA::MMA_64x216x16_F16F16F16_RS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 208 == 0) {
+        return SM90::GMMA::MMA_64x208x16_F16F16F16_RS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 200 == 0) {
+        return SM90::GMMA::MMA_64x200x16_F16F16F16_RS{};
+      }
+#endif
+      else if constexpr (Tile_N % 192 == 0) {
+        return SM90::GMMA::MMA_64x192x16_F16F16F16_RS{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 184 == 0) {
+        return SM90::GMMA::MMA_64x184x16_F16F16F16_RS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 176 == 0) {
+        return SM90::GMMA::MMA_64x176x16_F16F16F16_RS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 168 == 0) {
+        return SM90::GMMA::MMA_64x168x16_F16F16F16_RS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 160 == 0) {
+        return SM90::GMMA::MMA_64x160x16_F16F16F16_RS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 152 == 0) {
+        return SM90::GMMA::MMA_64x152x16_F16F16F16_RS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 144 == 0) {
+        return SM90::GMMA::MMA_64x144x16_F16F16F16_RS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 136 == 0) {
+        return SM90::GMMA::MMA_64x136x16_F16F16F16_RS{};
+      }
+#endif
+      else if constexpr (Tile_N % 128 == 0) {
+        return SM90::GMMA::MMA_64x128x16_F16F16F16_RS{};
+      }
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 120 == 0) {
+        return SM90::GMMA::MMA_64x120x16_F16F16F16_RS{};
+      }
+#endif
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED)
+      else if constexpr (Tile_N % 112 == 0) {
+        return
SM90::GMMA::MMA_64x112x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x16_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x16_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x16_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x16_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x16_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x16_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x16_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x16_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x16_F16F16F16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x16_F16F16F16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return 
SM90::GMMA::MMA_64x192x32_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E4M3E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr 
(is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + 
return SM90::GMMA::MMA_64x88x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E4M3E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E5M2E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return 
SM90::GMMA::MMA_64x256x32_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return 
SM90::GMMA::MMA_64x72x32_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F16E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F16E5M2E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // F32 accumulator + else if constexpr (is_same_v) { + + // Input A: half_t ; Input B: half_t + if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x16_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x16_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x16_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return 
SM90::GMMA::MMA_64x152x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x16_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x16_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x16_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x16_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x16_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x16_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x16_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x16_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x16_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x16_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x16_F32F16F16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x16_F32F16F16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: bfloat16_t ; Input B: bfloat16_t + else if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x16_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + 
else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x16_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x16_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x16_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x16_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x16_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x16_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x16_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x16_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x16_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x16_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x16_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return 
SM90::GMMA::MMA_64x24x16_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x16_F32BF16BF16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x16_F32BF16BF16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: tfloat32_t ; Input B: tfloat32_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 8 == 0, "Tile_K must be a multiple of 8."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x8_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x8_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x8_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x8_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x8_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return 
SM90::GMMA::MMA_64x112x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x8_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x8_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x8_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x8_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x8_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x8_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x8_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x8_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x8_F32TF32TF32_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x8_F32TF32TF32_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 
== 0) { + return SM90::GMMA::MMA_64x192x32_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F32E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F32E4M3E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else 
if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 
88 == 0) { + return SM90::GMMA::MMA_64x88x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F32E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F32E4M3E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E5M2E4M3_RS_TN{}; + } 
+#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::MMA_64x72x32_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F32E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F32E5M2E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return 
SM90::GMMA::MMA_64x256x32_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::MMA_64x248x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::MMA_64x232x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::MMA_64x216x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::MMA_64x200x32_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::MMA_64x184x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::MMA_64x168x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::MMA_64x152x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::MMA_64x136x32_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::MMA_64x120x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::MMA_64x104x32_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::MMA_64x88x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return 
SM90::GMMA::MMA_64x72x32_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::MMA_64x56x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::MMA_64x40x32_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_F32E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_F32E5M2E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // S32 accumulator + else if constexpr (is_same_v) { + + // Input A: int8_t ; Input B: int8_t + if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + 
return SM90::GMMA::MMA_64x48x32_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32S8S8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32S8S8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: int8_t ; Input B: uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32S8U8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32S8U8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: int8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this 
config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32U8S8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32U8S8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::MMA_64x256x32_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::MMA_64x240x32_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::MMA_64x224x32_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::MMA_64x208x32_S32U8U8_RS_TN{}; + } 
+#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::MMA_64x192x32_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::MMA_64x176x32_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::MMA_64x160x32_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::MMA_64x144x32_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::MMA_64x128x32_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::MMA_64x112x32_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::MMA_64x96x32_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::MMA_64x80x32_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::MMA_64x64x32_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::MMA_64x48x32_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::MMA_64x32x32_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::MMA_64x24x32_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::MMA_64x16x32_S32U8U8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::MMA_64x8x32_S32U8U8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // Unknown accumulator type + else { + static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type."); + } +} + +template < + class ElementA, + class ElementB, + class ElementC, + class TileShape_MNK, + GMMA::Major MajorA = GMMA::Major::K, + GMMA::Major MajorB = GMMA::Major::K, + auto... Args // e.g. 
GMMA::ScaleOut::One, [GMMA::ScaleIn::One, GMMA::ScaleIn::One] + // But most commonly leave empty for defaults +> +CUTE_HOST_DEVICE constexpr +auto +rs_op_selector_sparse() +{ + static_assert(is_static::value, "TileShape_MNK must be static."); + static_assert(rank(TileShape_MNK{}) == 3, "TileShape_MNK must be rank 3."); + static_assert(size<0>(TileShape_MNK{}) % 64 == 0, "Tile_M must be a multiple of 64."); + static_assert(MajorA == GMMA::Major::K, "Register source A operand GMMAs must have K-major A layout."); + auto Tile_N = size<1>(TileShape_MNK{}); + + // F16 accumulator + if constexpr (is_same_v) { + + // Input A: half_t ; Input B: half_t + if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x32_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x32_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x32_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x32_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x32_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x120x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x32_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x32_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x32_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x32_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x32_F16F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x32_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x32_F16F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x32_F16F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x32_F16F16F16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x32_F16F16F16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if 
constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E4M3_RS_TN{}; + } 
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E4M3_RS_TN{}; + } +#endif 
+#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E4M3_RS_TN{}; + } +#endif +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E5M2_RS_TN{}; + } 
+#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // F32 accumulator + else if constexpr (is_same_v) { + + // Input A: half_t ; Input B: half_t + if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x32_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x32_F32F16F16_RS{}; + } 
+#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x32_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x32_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x32_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x32_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x32_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x32_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x32_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x32_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x48x32_F32F16F16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x32_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x32_F32F16F16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x32_F32F16F16_RS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x32_F32F16F16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x32_F32F16F16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: bfloat16_t ; Input B: bfloat16_t + else if constexpr (is_same_v && is_same_v) { + static_assert(size<2>(TileShape_MNK{}) % 32 == 0, "Tile_K must be a multiple of 32."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x32_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x32_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x32_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x32_F32BF16BF16_RS{}; + } +#endif + else if 
constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x32_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x32_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x32_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x32_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x32_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x32_F32BF16BF16_RS{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x32_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x32_F32BF16BF16_RS{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x32_F32BF16BF16_RS{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x32_F32BF16BF16_RS{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x32_F32BF16BF16_RS{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: tfloat32_t ; Input B: tfloat32_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 16 == 0, "Tile_K must be a multiple of 16."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x16_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x16_F32TF32TF32_RS_TN{}; + } +#endif +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x16_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x16_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x16_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x16_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x16_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x16_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x16_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x16_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x16_F32TF32TF32_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) 
+ else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x16_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x16_F32TF32TF32_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x16_F32TF32TF32_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x16_F32TF32TF32_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x16_F32TF32TF32_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e4m3_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e4m3_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { 
+ return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E4M3_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E4M3_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E4M3_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E4M3_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E4M3_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: float_e5m2_t ; Input B: float_e5m2_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 248 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if 
defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 232 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 216 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 200 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 184 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 168 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 152 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 136 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 120 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 104 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 88 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 72 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 56 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E5M2_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 40 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E5M2_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E5M2_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E5M2_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E5M2_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // S32 accumulator + else if constexpr (is_same_v) { + + // Input A: int8_t ; Input B: int8_t + if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_RS_TN{}; + } +#endif + 
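+      // Selection note: within each (ElementA, ElementB) combination this cascade walks Tile_N from
+      // 256 down to 8 and returns the widest sparse GMMA atom whose N divides Tile_N evenly; the
+      // shapes that are not multiples of 64 (240, 224, 208, 176, 160, 144, 112, 80, 48, 24) are only
+      // candidates when CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED is defined. For example, with the
+      // extended shapes enabled a Tile_N of 80 selects GMMA_64x80x64_S32S8S8_RS_TN; without them it
+      // falls through to GMMA_64x16x64_S32S8S8_RS_TN.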
else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: int8_t ; Input B: uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return 
SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: int8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8S8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + // Input A: uint8_t ; Input B: uint8_t + else if constexpr (is_same_v && is_same_v) { + static_assert(MajorA == GMMA::Major::K, "MajorA must be GMMA::Major::K for this config."); + static_assert(MajorB == GMMA::Major::K, "MajorB must be GMMA::Major::K for this config."); + static_assert(size<2>(TileShape_MNK{}) % 64 == 0, "Tile_K must be a multiple 
of 64."); + + if constexpr (Tile_N % 256 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 240 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 224 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 208 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 192 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 176 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 160 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_RS_TN{}; + } +#endif +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 144 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 128 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 112 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 96 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 80 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 64 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 48 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 32 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_RS_TN{}; + } +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) + else if constexpr (Tile_N % 24 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8U8_RS_TN{}; + } +#endif + else if constexpr (Tile_N % 16 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_RS_TN{}; + } + else if constexpr (Tile_N % 8 == 0) { + return SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_RS_TN{}; + } + else { + static_assert(Tile_N % 8 == 0, "Tile_N must be a multiple of 8."); + } + } + + else { + static_assert(sizeof(ElementA) == 0, "No eligible GMMA operator for request configuration."); + } + } + + // Unknown accumulator type + else { + static_assert(sizeof(ElementC) == 0, "Unknown ElementC accumulator type."); + } +} + +} // end namespace SM90::GMMA +} // end namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/arch/mma_sm90_desc.hpp b/include/cute/arch/mma_sm90_desc.hpp new file mode 100644 index 0000000000..a53a9748b4 --- /dev/null +++ b/include/cute/arch/mma_sm90_desc.hpp @@ -0,0 +1,156 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include + +#include + +// Config +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) +# define CUTE_ARCH_MMA_SM90A_ENABLED +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cute { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// GMMA Descriptor and utilities + +// GMMA enums and utilities +namespace SM90::GMMA { + +enum class LayoutType : uint8_t { + INTERLEAVE = 0, + B128 = 1, + B64 = 2, + B32 = 3, +}; + +CUTE_HOST_DEVICE char const* to_string(LayoutType const& t) { + switch (t) { + case LayoutType::INTERLEAVE: return "INTERLEAVE"; + case LayoutType::B128: return "B128"; + case LayoutType::B64: return "B64"; + case LayoutType::B32: return "B32"; + } + return nullptr; +} + +#if !defined(__CUDACC_RTC__) +// Output operator for all enums in this namespace +CUTE_HOST std::ostream& operator<<(std::ostream& os, LayoutType const& t) { + char const* s = to_string(t); + if (s) { + std::operator<<(os, s); // Explicit call to avoid ambiguity + } else { + os.setstate(std::ios_base::failbit); + } + return os; +} +#endif // !defined(__CUDACC_RTC__) + +} // end namespace SM90::GMMA + +union GmmaDescriptor +{ + CUTE_HOST_DEVICE constexpr + GmmaDescriptor() noexcept : desc_(0) {} + CUTE_HOST_DEVICE constexpr + GmmaDescriptor(uint64_t desc) noexcept : desc_(desc) {} + CUTE_HOST_DEVICE constexpr + GmmaDescriptor(GmmaDescriptor const& t) noexcept : desc_(t.desc_) {} + CUTE_HOST_DEVICE constexpr + GmmaDescriptor(GmmaDescriptor && t) noexcept : desc_(t.desc_) {} + + CUTE_HOST_DEVICE constexpr + GmmaDescriptor& operator=(GmmaDescriptor const& t) noexcept { + desc_ = t.desc_; + return *this; + } + + CUTE_HOST_DEVICE constexpr + GmmaDescriptor& operator=(GmmaDescriptor && t) 
noexcept {
+    desc_ = t.desc_;
+    return *this;
+  }
+
+  uint64_t desc_;
+  uint32_t reg32_[2];
+  uint16_t reg16_[4];
+
+  // Bitfield implementation avoids the need for shifts in assignment
+  struct {
+    // start_address, bit [0,14), 4LSB not included
+    uint16_t start_address_ : 14, : 2;        // 14 bits [0,14), 2 bits unused
+    // leading dimension byte offset, bit [16,30), 4LSB not included
+    // For N: This is the stride from the first col to the second col of the 8x2 brick in INTERLEAVED
+    // Unused for all SWIZZLE_* layouts (and assumed to be 1)
+    // For T: This is the stride from the first 8 rows to the next 8 rows.
+    uint16_t leading_byte_offset_ : 14, : 2;  // 14 bits [0,14), 2 bits unused
+    // stride dimension byte offset, bit [32,46), 4LSB not included
+    // For N: This is the stride from the first 8 rows to the next 8 rows.
+    // For T: This is the stride from the first 8 cols to the next 8 cols.
+    uint16_t stride_byte_offset_ : 14, : 2;   // 14 bits [0,14), 2 bits unused
+    // base_offset, bit [49,52)
+    // Valid only for SWIZZLE_128B and SWIZZLE_64B
+    uint8_t : 1, base_offset_ : 3, : 4;       // 1 bit unused, 3 bits [1,4), 4 bits unused
+    // layout type, bit [62,64)
+    // SWIZZLE_NONE = 0, SWIZZLE_32B = 3, SWIZZLE_64B = 2, SWIZZLE_128B = 1
+    uint8_t : 6, layout_type_ : 2;            // 6 bits unused, 2 bits [6,8)
+  } bitfield;
+
+  // Decay to a uint64_t
+  CUTE_HOST_DEVICE constexpr
+  operator uint64_t() const noexcept { return desc_; }
+};
+
+// Printer
+CUTE_HOST_DEVICE void
+print(GmmaDescriptor const& t)
+{
+#if !defined(__CUDACC_RTC__)
+  printf("GmmaDescriptor: 0x%016llx\n", static_cast<unsigned long long>(t.desc_));
+  printf("  start_addr : 0x%04x\n", t.bitfield.start_address_);
+  printf("  leading_off: 0x%04x (%d)\n", t.bitfield.leading_byte_offset_, t.bitfield.leading_byte_offset_);
+  printf("  stride_off : 0x%04x (%d)\n", t.bitfield.stride_byte_offset_, t.bitfield.stride_byte_offset_);
+  printf("  base_offset: 0x%01x\n", t.bitfield.base_offset_);
+  printf("  layout_type: 0x%01x (%s)\n", t.bitfield.layout_type_, to_string(static_cast<SM90::GMMA::LayoutType>(t.bitfield.layout_type_)));
+#endif // !defined(__CUDACC_RTC__)
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace cute
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/include/cute/arch/mma_sm90_gmma.hpp b/include/cute/arch/mma_sm90_gmma.hpp
new file mode 100644
index 0000000000..d809aa4a63
--- /dev/null
+++ b/include/cute/arch/mma_sm90_gmma.hpp
@@ -0,0 +1,20974 @@
+/***************************************************************************************************
+ * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+#pragma once
+
+#include <cute/config.hpp>          // CUTE_HOST_DEVICE
+
+#include "cutlass/arch/synclog.hpp"
+
+// Config
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL))
+# define CUTE_ARCH_MMA_SM90A_ENABLED
+#endif
+
+namespace cute {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// Warpgroup sync primitives
+
+CUTE_HOST_DEVICE
+void
+warpgroup_arrive()
+{
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+  cutlass::arch::synclog_emit_warpgroup_arrive(__LINE__);
+  asm volatile ("wgmma.fence.sync.aligned;\n" ::: "memory");
+#else
+  CUTE_INVALID_CONTROL_PATH("Attempting to use wgmma.fence without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+}
+
+template <int N>
+CUTE_HOST_DEVICE
+void
+warpgroup_wait()
+{
+  static_assert(N >= 0 && N <= 7, "WGMMA wait: N must be in range [0, 7]");
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+  cutlass::arch::synclog_emit_warpgroup_wait(__LINE__, N);
+  asm volatile("wgmma.wait_group.sync.aligned %0;\n" :: "n"(N) : "memory");
+#else
+  CUTE_INVALID_CONTROL_PATH("Attempting to use wgmma.wait_group without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+}
+
+// Marks the commit point for one or more batches of warpgroup MMAs.
+CUTE_HOST_DEVICE
+void
+warpgroup_commit_batch()
+{
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+  cutlass::arch::synclog_emit_warpgroup_commit_batch(__LINE__);
+  asm volatile("wgmma.commit_group.sync.aligned;\n" ::: "memory");
+#else
+  CUTE_INVALID_CONTROL_PATH("Attempting to use wgmma.commit_group without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+}
+
+CUTE_HOST_DEVICE
+void
+warpgroup_fence_operand(uint32_t& reg) {
+  // MSVC emits a build error for 'asm volatile'
+  // even if it only occurs in a __device__ function.
+  // This prevents the error.
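+  //
+  // Typical composition of these primitives (illustrative sketch only; `acc_regs` is a
+  // hypothetical array of accumulator registers, not something defined in this header):
+  //
+  //   for (uint32_t& r : acc_regs) { warpgroup_fence_operand(r); }
+  //   warpgroup_arrive();            // wgmma.fence before issuing the async MMAs
+  //   ... issue one or more wgmma.mma_async operations ...
+  //   warpgroup_commit_batch();      // close the batch of MMAs issued above
+  //   warpgroup_wait<0>();           // wait until no committed batches remain pending
+  //   for (uint32_t& r : acc_regs) { warpgroup_fence_operand(r); }
+  //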
+#if defined(__CUDA_ARCH__) + asm volatile("" : "+r"(reg) :: "memory"); +#endif +} + +CUTE_HOST_DEVICE +void +warpgroup_fence_operand(float& reg) { +#if defined(__CUDA_ARCH__) + asm volatile("" : "+f"(reg) :: "memory"); +#endif +} + +namespace SM90::GMMA { + +enum class Major { + K = 0, + MN = 1 +}; + +enum class ScaleOut { + Zero = 0, + One = 1 +}; + +enum class ScaleIn { + Neg = -1, + One = 1 +}; + +enum class SparseSel { + Zero = 0, + One = 1 +}; + + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// GMMA PTX definitions: C = (scaleA * A) * (scaleB * B) + (scaleD * C) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 " + "{%0, %1}," + " %2," + " %3," + " p, %5, %6, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %7, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " p, %8, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct 
MMA_64x16x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f16.f16.f16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), 
"+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15, %16;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x64x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), 
"+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t 
& d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, 
uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, 
uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & 
d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71, %72;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f32.f16.f16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15, %16;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, 
%6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + 
float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), 
+ "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x16 F32+=F16*F16 
+template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + 
CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float 
& d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103, %104;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), 
"+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, 
%58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & 
d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135, %136;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), 
"+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k16.f32.bf16.bf16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x16_F32BF16BF16_RS without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15, %16;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, 
+ float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, 
+ float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), 
"+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + 
CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float 
& d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & 
d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103, %104;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & 
d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), 
"+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, 
float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135, %136;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One +> +struct MMA_64x8x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k8.f32.tf32.tf32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x8_F32TF32TF32_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& 
a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = 
uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " 
%48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, 
uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, 
float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = 
float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " 
%120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & 
d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), 
"r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*S8 +struct MMA_64x8x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*S8 +struct MMA_64x8x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*S8 +struct MMA_64x16x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*S8 +struct MMA_64x16x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*S8 +struct MMA_64x32x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*S8 +struct MMA_64x32x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 
p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=S8*S8 +struct MMA_64x64x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=S8*S8 +struct MMA_64x64x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*S8 +struct MMA_64x96x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8S8_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*S8 +struct MMA_64x96x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*S8 +struct MMA_64x128x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & 
d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*S8 +struct MMA_64x128x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & 
d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*S8 +struct MMA_64x192x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + 
uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*S8 +struct MMA_64x192x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & 
d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*S8 +struct MMA_64x256x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, 
+ uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), 
"+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*S8 +struct MMA_64x256x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & 
d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*S8 +struct MMA_64x8x32_S32S8S8_RS_TN +{ + using 
DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*S8 +struct MMA_64x8x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*S8 +struct MMA_64x16x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*S8 +struct MMA_64x16x32_S32S8S8_RS_TN_SATURATE +{ + using 
DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*S8 +struct MMA_64x32x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*S8 +struct MMA_64x32x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + 
"{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=S8*S8 +struct MMA_64x64x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=S8*S8 +struct MMA_64x64x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & 
d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*S8 +struct MMA_64x96x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), 
"+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*S8 +struct MMA_64x96x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*S8 +struct MMA_64x128x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, 
uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*S8 +struct MMA_64x128x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, 
uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*S8 +struct MMA_64x192x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & 
d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*S8 +struct MMA_64x192x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t 
& d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), 
"+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*S8 +struct MMA_64x256x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " 
%56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*S8 +struct MMA_64x256x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, 
uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), 
+ "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*U8 +struct MMA_64x8x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*U8 +struct MMA_64x8x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*U8 +struct MMA_64x16x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*U8 +struct MMA_64x16x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*U8 +struct MMA_64x32x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*U8 +struct MMA_64x32x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = 
uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=S8*U8 +struct MMA_64x64x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=S8*U8 +struct MMA_64x64x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE 
static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*U8 +struct MMA_64x96x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), 
"+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*U8 +struct MMA_64x96x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*U8 +struct MMA_64x128x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + 
CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*U8 +struct MMA_64x128x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & 
d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*U8 +struct MMA_64x192x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, 
uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*U8 +struct MMA_64x192x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t 
& d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8U8_SS_TN_SATURATE without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*U8 +struct MMA_64x256x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, 
%100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*U8 +struct MMA_64x256x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + 
uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), 
"+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*U8 +struct MMA_64x8x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=S8*U8 +struct MMA_64x8x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*U8 +struct MMA_64x16x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=S8*U8 +struct MMA_64x16x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*U8 +struct MMA_64x32x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=S8*U8 +struct MMA_64x32x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = 
uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=S8*U8 +struct MMA_64x64x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 
64x64x32 TN S32+=S8*U8 +struct MMA_64x64x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*U8 +struct MMA_64x96x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8 " + "{%0, %1, 
%2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=S8*U8 +struct MMA_64x96x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), 
"r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*U8 +struct MMA_64x128x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=S8*U8 +struct 
MMA_64x128x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*U8 +struct MMA_64x192x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, 
uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : 
"r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=S8*U8 +struct MMA_64x192x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), 
"+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*U8 +struct MMA_64x256x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, 
uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=S8*U8 +struct MMA_64x256x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & 
d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), 
"+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*S8 +struct MMA_64x8x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*S8 +struct MMA_64x8x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + 
"r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*S8 +struct MMA_64x16x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*S8 +struct MMA_64x16x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*S8 +struct MMA_64x32x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, 
%11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*S8 +struct MMA_64x32x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*S8 +struct MMA_64x64x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + 
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*S8 +struct MMA_64x64x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*S8 +struct MMA_64x96x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & 
d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*S8 +struct MMA_64x96x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), 
"+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*S8 +struct MMA_64x128x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*S8 +struct MMA_64x128x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*S8 +struct MMA_64x192x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, 
uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + 
"r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*S8 +struct MMA_64x192x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + 
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*S8 +struct MMA_64x256x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*S8 +struct MMA_64x256x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + 
uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + 
"+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*S8 +struct MMA_64x8x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*S8 +struct MMA_64x8x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*S8 +struct MMA_64x16x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*S8 +struct MMA_64x16x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*S8 +struct MMA_64x32x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8 " + 
"{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*S8 +struct MMA_64x32x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*S8 +struct MMA_64x64x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, 
%29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*S8 +struct MMA_64x64x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*S8 +struct MMA_64x96x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, 
uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*S8 +struct MMA_64x96x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " 
%24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*S8 +struct MMA_64x128x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), 
"+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*S8 +struct MMA_64x128x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), 
"+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*S8 +struct MMA_64x192x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + 
"+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*S8 +struct MMA_64x192x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, 
%44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*S8 +struct MMA_64x256x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + 
uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : 
"r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*S8 +struct MMA_64x256x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " 
%56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*U8 +struct MMA_64x8x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*U8 +struct 
MMA_64x8x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*U8 +struct MMA_64x16x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*U8 +struct MMA_64x16x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*U8 +struct MMA_64x32x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters 
= uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*U8 +struct MMA_64x32x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*U8 +struct MMA_64x64x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + 
uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*U8 +struct MMA_64x64x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*U8 +struct MMA_64x96x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& 
desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*U8 +struct MMA_64x96x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg 
.pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*U8 +struct MMA_64x128x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + 
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*U8 +struct MMA_64x128x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), 
"+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*U8 +struct MMA_64x192x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), 
"+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*U8 +struct MMA_64x192x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, 
%68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*U8 +struct MMA_64x256x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, 
uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
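+// Note: in the *_SS_TN integer GMMA atoms above, both A and B are read from shared memory
+// through 64-bit wgmma matrix descriptors (desc_a, desc_b). The predicate `p` computed from
+// `scale_D` selects between D = A*B (GMMA::ScaleOut::Zero) and D = A*B + D (ScaleOut::One).
+// The *_SATURATE twins emit the same instruction with the `.satfinite` qualifier, which
+// clamps the signed 32-bit accumulator to its representable range instead of wrapping.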
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*U8 +struct MMA_64x256x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " 
%104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*U8 +struct MMA_64x8x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN S32+=U8*U8 +struct MMA_64x8x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t 
const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*U8 +struct MMA_64x16x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN S32+=U8*U8 +struct MMA_64x16x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*U8 +struct MMA_64x32x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + 
using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN S32+=U8*U8 +struct MMA_64x32x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*U8 +struct MMA_64x64x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t 
& d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN S32+=U8*U8 +struct MMA_64x64x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
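+// Note: the *_RS_TN variants in this block source operand A from registers (four 32-bit
+// fragments per thread, ARegisters = uint32_t[4]) while operand B is still referenced through
+// a shared-memory wgmma descriptor; scale_D and the .satfinite handling are identical to the
+// *_SS_TN atoms above.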
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*U8 +struct MMA_64x96x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN S32+=U8*U8 +struct MMA_64x96x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & 
d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*U8 +struct MMA_64x128x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8 " + "{%0, 
%1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN S32+=U8*U8 +struct MMA_64x128x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, 
%65, %66, %67}," + " %68," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*U8 +struct MMA_64x192x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, 
%38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN S32+=U8*U8 +struct MMA_64x192x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, 
uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*U8 +struct MMA_64x256x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & 
d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), 
+ "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN S32+=U8*U8 +struct MMA_64x256x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + 
uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t 
const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e4m3 " + "{%0, %1}," + " %2," + " %3," + " p, %5, %6;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %7, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e4m3 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " p, %8, %9;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; 
+ + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E4M3*E4M3 
+template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + 
" %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn 
scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; 
+ + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn 
scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), 
"+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), 
"+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & 
d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), 
"+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), 
"+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using 
BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + 
float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, 
float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
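//
// Usage sketch (editorial illustration, hedged): the atoms above share one calling
// convention -- accumulator fragments are passed by reference, A arrives either as
// four 32-bit register words (RS) or as a 64-bit shared-memory descriptor (SS), and
// scale_D is lowered to the predicate `p`, selecting D = A*B (Zero) versus
// D = A*B + D (One). The sketch below drives MMA_64x8x32_F16E4M3E4M3_SS_TN through
// the standard warpgroup fence/commit/wait sequence. Assumptions: the atom is
// visible from the enclosing namespace as written, and desc_a/desc_b are valid GMMA
// shared-memory descriptors built elsewhere (descriptor construction is omitted).
//
// #include <cute/arch/mma_sm90_gmma.hpp>

CUTE_DEVICE void
fp8_gmma_64x8x32_sketch(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[2])
{
  cute::warpgroup_fence_operand(acc[0]);        // keep the compiler from reordering
  cute::warpgroup_fence_operand(acc[1]);        //   accesses to the accumulator registers
  cute::warpgroup_arrive();                     // wgmma.fence before issuing MMAs
  MMA_64x8x32_F16E4M3E4M3_SS_TN<>::fma(
      desc_a, desc_b,
      acc[0], acc[1],
      cute::GMMA::ScaleOut::One);               // One: accumulate into acc; Zero: overwrite
  cute::warpgroup_commit_batch();               // wgmma.commit_group
  cute::warpgroup_wait<0>();                    // wait for all committed wgmma groups
  cute::warpgroup_fence_operand(acc[0]);
  cute::warpgroup_fence_operand(acc[1]);
}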
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F16E4M3E4M3_RS_TN +{ + using 
DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& 
desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), 
"+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float 
& d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + 
CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e5m2 " + "{%0, %1}," + " %2," + " %3," + " p, %5, %6;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %7, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e4m3.e5m2 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " p, %8, %9;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = 
uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
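Every atom in this family routes scale_D through the `setp.ne.b32` predicate guarding the accumulator: `GMMA::ScaleOut::Zero` overwrites the D tile with A*B, while `ScaleOut::One` accumulates into the existing contents. As a plain element-level reference of that contract (hypothetical names, per-thread fragment layout ignored):

// Element-level reference of the scale_D contract; illustrative only, not part of this
// change. A is M x K and B is N x K, both K-contiguous, matching the TN layouts these
// atoms target; the real instruction's per-thread fragment mapping is not modeled.
#include <cstddef>
#include <vector>

void ref_wgmma_tile(std::vector<float>& D, std::vector<float> const& A,
                    std::vector<float> const& B, int M, int N, int K, int scale_d) {
  for (int m = 0; m < M; ++m) {
    for (int n = 0; n < N; ++n) {
      // scale_d == 0 (ScaleOut::Zero) overwrites the tile; scale_d != 0 accumulates,
      // which is what the setp.ne.b32 predicate operand selects in the asm above.
      float acc = scale_d ? D[size_t(m) * N + n] : 0.0f;
      for (int k = 0; k < K; ++k) {
        acc += A[size_t(m) * K + k] * B[size_t(n) * K + k];
      }
      D[size_t(m) * N + n] = acc;
    }
  }
}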
+// GMMA 64x16x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e5m2 " + 
"{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN 
F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = 
uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 
TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), 
"+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut 
const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), 
"+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, 
uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), 
"+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), 
"+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; 
+ using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, 
float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + 
float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E4M3E5M2_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> 
+struct MMA_64x256x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void 
+ fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), 
"+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, 
float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = 
uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e4m3 " + "{%0, %1}," + " %2," + " %3," + " p, %5, %6;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %7, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e4m3 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " p, %8, %9;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + 
using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
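+// Illustrative note (not part of the original header): each shape in this family comes
+// in an "SS" flavor, where both A and B are read from shared memory through 64-bit
+// wgmma descriptors, and an "RS" flavor, where the A fragment is held in four 32-bit
+// registers and only B uses a descriptor. As a minimal sketch, a direct call of the
+// MMA_64x8x32_F32E5M2E4M3_SS_TN atom defined above could look like this (desc_a and
+// desc_b are assumed to be valid GMMA shared-memory descriptors built elsewhere, and
+// the surrounding warpgroup fence/commit/wait protocol remains the caller's job):
+//
+//   float acc[4] = {0.f, 0.f, 0.f, 0.f};            // per-thread accumulator fragment
+//   MMA_64x8x32_F32E5M2E4M3_SS_TN<>::fma(
+//       desc_a, desc_b,
+//       acc[0], acc[1], acc[2], acc[3],
+//       GMMA::ScaleOut::Zero);                       // overwrite D instead of accumulating
+//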
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg 
.pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
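+// Illustrative note (not part of the original header): the prologue shared by every
+// fma() above does two things. The cutlass::arch::synclog_emit_wgmma_* call records the
+// wgmma issue for the synclog debugging tool and is expected to compile to nothing when
+// synclog is disabled. The "setp.ne.b32 p, %N, 0" line converts the runtime scale_D
+// argument into the predicate that wgmma.mma_async uses to choose between accumulating
+// (GMMA::ScaleOut::One, D = A*B + D) and overwriting (GMMA::ScaleOut::Zero, D = A*B),
+// e.g. for the 64x8x32 F16 atom defined earlier (d0, d1 being the uint32_t fragment):
+//
+//   MMA_64x8x32_F16E5M2E4M3_SS_TN<>::fma(desc_a, desc_b, d0, d1, GMMA::ScaleOut::Zero); // first k-block
+//   MMA_64x8x32_F16E5M2E4M3_SS_TN<>::fma(desc_a, desc_b, d0, d1);                       // later k-blocks accumulate
+//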
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> 
+struct MMA_64x64x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
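+// Illustrative note (not part of the original header): these atoms are normally consumed
+// through CuTe's tiled-MMA machinery rather than by calling fma() directly. A minimal
+// sketch, assuming a matching cute::MMA_Traits specialization exists for this shape (as
+// it does for the other GMMA atoms) and writing the atom name as it appears in this file:
+//
+//   using Atom     = cute::MMA_Atom<MMA_64x64x32_F32E5M2E4M3_SS_TN<>>;
+//   auto tiled_mma = cute::make_tiled_mma(Atom{});   // TiledMMA used to partition A/B/C tensors
+//
+// The collective mainloop then builds the shared-memory descriptors and drives scale_D,
+// so user code rarely touches the raw fma() entry points.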
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : 
"+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, 
+ float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + 
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t 
& d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, 
%63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + 
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float 
& d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + 
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); 
+#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E5M2*E4M3 +template < + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using 
BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), 
"+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + 
float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct 
MMA_64x8x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %4, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e5m2 " + "{%0, %1}," + " %2," + " %3," + " p, %5, %6;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %7, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f16.e5m2.e5m2 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " p, %8, %9;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x8x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x8x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n8k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x8x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x16x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x16x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x16x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n16k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x16x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x32x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x32x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x32x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n32k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x32x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F16+=E5M2*E5M2 +template < + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to 
use MMA_64x64x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x64x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x64x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n64k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x64x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, 
%17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, 
float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x96x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x96x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n96k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + 
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x96x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, 
uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, 
%47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x128x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x128x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n128k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + 
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x128x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 
GMMA 64x192x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & 
d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %98, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " p, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x192x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x192x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using 
BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %101, 0;\n" + "wgmma.mma_async.sync.aligned.m64n192k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " p, %102, %103;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), 
"r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x192x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %130, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " p, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), 
"+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x256x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x256x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & 
d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %133, 0;\n" + "wgmma.mma_async.sync.aligned.m64n256k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " p, %134, %135;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x256x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // 
namespace SM90::GMMA + +} // namespace cute + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +#include "mma_sm90_gmma_ext.hpp" +#endif diff --git a/include/cute/arch/mma_sm90_gmma_ext.hpp b/include/cute/arch/mma_sm90_gmma_ext.hpp new file mode 100644 index 0000000000..10a36aff80 --- /dev/null +++ b/include/cute/arch/mma_sm90_gmma_ext.hpp @@ -0,0 +1,56445 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+
+ *
+ **************************************************************************************************/
+
+#pragma once
+
+#include <cute/config.hpp>   // CUTE_HOST_DEVICE
+
+#include "cutlass/arch/synclog.hpp"
+
+// Config
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL))
+#  define CUTE_ARCH_MMA_SM90A_ENABLED
+#endif
+
+namespace cute {
+
+namespace SM90::GMMA {
+
+// GMMA 64x24x16 F16+=F16*F16
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x24x16_F16F16F16_SS
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[6];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
+      uint32_t & d4, uint32_t & d5,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %8, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n24k16.f16.f16.f16 "
+      "{%0, %1, %2, %3, %4, %5},"
+      " %6,"
+      " %7,"
+      " p, %9, %10, %11, %12;\n"
+    "}\n"
+      : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3),
+        "+r"(d4), "+r"(d5)
+      : "l"(desc_a),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x24x16 F16+=F16*F16
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x24x16_F16F16F16_RS
+{
+  using DRegisters = void;
+  using ARegisters = uint32_t[4];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[6];
+
+  static_assert(tnspA == GMMA::Major::K,
+      "Register source operand A must have K major layout.");
+
+  CUTE_HOST_DEVICE static void
+  fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3,
+      uint64_t const& desc_b,
+      uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
+      uint32_t & d4, uint32_t & d5,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %11, 0;\n"
+      "wgmma.mma_async.sync.aligned.m64n24k16.f16.f16.f16 "
+      "{%0, %1, %2, %3, %4, %5},"
+      "{%6, %7, %8, %9},"
+      " %10,"
+      " p, %12, %13, %14;\n"
+    "}\n"
+      : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3),
+        "+r"(d4), "+r"(d5)
+      : "r"(a0), "r"(a1), "r"(a2), "r"(a3),
+        "l"(desc_b),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// GMMA 64x40x16 F16+=F16*F16
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+struct MMA_64x40x16_F16F16F16_SS
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+ using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " p, %13, %14, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " p, %16, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + 
".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " p, %17, %18, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " p, %20, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " p, %21, %22, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " p, %24, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, 
uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " p, %25, %26, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " p, %28, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// 
GMMA 64x104x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " p, %29, %30, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + 
" p, %32, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & 
d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " p, %33, %34, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " p, %36, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, 
+ uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " p, %37, %38, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " p, %40, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & 
d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " p, %41, %42, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), 
"+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " p, %44, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; 
+ using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " p, %45, %46, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), 
"+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " p, %48, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE 
static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + 
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " p, %49, %50, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), 
"+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " p, %52, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
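//
// [Editor's illustrative sketch -- not part of this change.] Each *_SS atom above wraps a single
// wgmma.mma_async instruction: operands A and B are shared-memory matrix descriptors (the
// uint64_t desc_a/desc_b arguments) and D is a per-thread register fragment that the 128-thread
// warpgroup owns collectively. The sketch below shows the minimal arrive/commit/wait pattern
// around a direct call to one of these atoms. Assumptions: the atom name resolves in this
// header's enclosing namespace, the warpgroup_arrive / warpgroup_commit_batch / warpgroup_wait
// helpers from cute/arch/mma_sm90_gmma.hpp are visible, the example_wgmma_64x24x16 function name
// is hypothetical, and desc_a/desc_b are valid GMMA shared-memory descriptors built by the caller.
//
#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
// Every thread of the warpgroup calls this with the same descriptors and its own accumulator fragment.
CUTE_HOST_DEVICE void
example_wgmma_64x24x16(uint64_t desc_a, uint64_t desc_b, float (&acc)[12])
{
  cute::warpgroup_arrive();                             // wgmma.fence: register/smem operands are ready
  MMA_64x24x16_F32F16F16_SS<GMMA::Major::K, GMMA::Major::K>::fma(
      desc_a, desc_b,
      acc[0], acc[1], acc[2], acc[3], acc[4],  acc[5],
      acc[6], acc[7], acc[8], acc[9], acc[10], acc[11],
      GMMA::ScaleOut::One);                             // One accumulates into acc; Zero overwrites it
  cute::warpgroup_commit_batch();                       // wgmma.commit_group
  cute::warpgroup_wait<0>();                            // block until all committed wgmma batches complete
}
#endif
//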
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " p, %53, %54, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t 
const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " p, %56, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + 
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n208k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " p, %57, %58, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), 
"+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " p, %60, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), 
"+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x16 
F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " p, %61, %62, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, 
uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " p, %64, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & 
d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, 
uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x16_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " p, %65, %66, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x16_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x16 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x16_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" 
+ ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k16.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " p, %68, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x16_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16, %17, %18;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + static_assert(tnspA == GMMA::Major::K, + 
"Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + 
CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), 
"+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const 
scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + 
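+    // Note: CUTE_ARCH_MMA_SM90A_ENABLED is defined only when compiling for the sm_90a target
+    // (wgmma.mma_async is an architecture-accelerated feature), so on any other architecture
+    // instantiating this atom falls through to the diagnostic below.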
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & 
d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + 
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x16_F32F16F16_SS +{ + using DRegisters = void; + using 
ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float 
& d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, 
%28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + 
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %70, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " p, %71, %72, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must 
have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %73, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " p, %74, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + 
float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p, %75, %76, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & 
d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p, %78, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & 
d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %78, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " p, %79, %80, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + 
float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %81, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " p, %82, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & 
d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p, %83, %84, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t 
const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p, %86, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; 
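+  // Accumulator fragment size: the 64x168 f32 D tile is distributed across the 128 threads
+  // of the warpgroup, so each thread owns 168/2 = 84 f32 registers; for this _SS atom both
+  // A and B are supplied as 64-bit shared-memory matrix descriptors rather than register fragments.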
+ using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %86, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " p, %87, %88, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA 
= GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %89, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " p, %90, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), 
"n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p, %91, %92, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), 
"+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p, %94, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), 
+ "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %94, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " 
+ " %88, %89, %90, %91}," + " %92," + " %93," + " p, %95, %96, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + 
".reg .pred p;\n" + "setp.ne.b32 p, %97, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " p, %98, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, 
float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %102, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " p, %103, %104, %105, %106;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& 
a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %105, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " p, %106, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), 
"+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p, %107, %108, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), 
"+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + 
float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p, %110, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & 
d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %110, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " p, %111, %112, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), 
"+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %113, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, 
%109, %110, %111}," + " %112," + " p, %114, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, 
float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p, %115, %116, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major 
layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p, %118, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), 
"+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %118, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, 
%15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " p, %119, %120, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, 
float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %121, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " p, %122, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), 
"+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " 
+ " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p, %123, %124, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, 
float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p, %126, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), 
"r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x16_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %126, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, 
%113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " p, %127, %128, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x16_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x16 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x16_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & 
d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %129, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k16.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " p, %130, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), 
"+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x16_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16, %17, %18;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x16_F32BF16BF16_RS without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), 
"r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, 
%14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, 
float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + 
float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), 
"+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + 
CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, 
desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), 
"+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x16_F32BF16BF16_SS +{ + using DRegisters = void; + 
using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, 
float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & 
d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %70, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " p, %71, %72, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float 
& d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %73, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " p, %74, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p, %75, %76, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p, %78, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + 
".reg .pred p;\n" + "setp.ne.b32 p, %78, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " p, %79, %80, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %81, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " p, %82, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p, %83, %84, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + 
float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p, %86, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, 
float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %86, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " p, %87, %88, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & 
d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %89, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " p, %90, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float 
& d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p, %91, %92, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using 
BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p, %94, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %94, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " p, %95, %96, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), 
"+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %97, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " p, %98, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), 
"+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %102, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " 
+ " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " p, %103, %104, %105, %106;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & 
d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %105, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " p, %106, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x16_F32BF16BF16_SS +{ + using 
DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p, %107, %108, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), 
"+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, 
%57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p, %110, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & 
d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %110, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " p, %111, %112, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using 
CRegisters = float[108]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %113, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " p, %114, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + 
"+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, 
%26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p, %115, %116, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, 
float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p, %118, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %118, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " p, %119, %120, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), 
"+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, 
float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %121, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " p, %122, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One +> +struct MMA_64x240x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p, %123, %124, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), 
"+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float 
& d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p, %126, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x16_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE 
static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %126, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " p, %127, %128, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), 
"+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x16_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x16 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x16_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & 
d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %129, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k16.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " p, %130, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x16_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters 
= float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + 
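+    // Operand layout of the wgmma.mma_async below, per PTX wgmma semantics:
+    //   %0-%19  : per-thread f32 accumulator fragment ("+f" read/write constraints)
+    //   %20, %21: 64-bit shared-memory matrix descriptors for A and B
+    //   p       : predicate from scale_D; 0 computes D = A*B, non-zero D = A*B + D
+    //   %23, %24: immediate input scales for A and B (GMMA::ScaleIn::One or ::Neg)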
asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + 
asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, 
float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = 
uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), 
"+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + 
float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), 
"+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + 
using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & 
d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), 
+ "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + 
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & 
d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %70, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " p, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & 
d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %73, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " p, %74, %75;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
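+    // Record this SMEM x SMEM WGMMA in the synclog debug trace, then issue the asynchronous MMA.
+    // Predicate p, derived from scale_D, selects accumulation: D = A*B + D when p is true, D = A*B otherwise.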
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k8.f32.tf32.tf32 
" + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p, %78, %79;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %78, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, 
%28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " p, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %81, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " 
%40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " p, %82, %83;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, 
%45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, 
%37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p, %86, %87;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %86, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, 
%21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " p, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %89, 
0;\n" + "wgmma.mma_async.sync.aligned.m64n168k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " p, %90, %91;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const 
scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float 
& d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p, %94, %95;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, 
float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %94, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " p, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float 
& d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %97, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " p, %98, %99;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x8_F32TF32TF32_SS_TN +{ + using DRegisters = 
void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %102, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " p, %103, %104;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), 
"+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %105, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " p, %106, %107;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), 
"+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, 
float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p, %110, %111;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
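
All of the SS/RS pairs in this file share one inline-PTX operand layout: the `N/2` accumulator registers occupy `%0..%(N/2-1)`, followed by the A operand (one smem descriptor for SS, four fragment registers for RS), the B descriptor, `scale_D`, and the `scaleA`/`scaleB` immediates. The predicate index in each `setp.ne.b32` line is therefore a pure function of the tile width N. A minimal standalone sketch (plain C++, no CUTLASS dependency) that encodes and checks that pattern against the atoms above:

```cpp
// Operand numbering used by the 64xNx8 wgmma atoms in this file:
//   SS: %0..%(N/2-1) accumulators, %(N/2) desc_a, %(N/2+1) desc_b,
//       %(N/2+2) scale_D, %(N/2+3) scaleA, %(N/2+4) scaleB
//   RS: the single desc_a is replaced by four A-fragment registers,
//       shifting the trailing operands by three.
constexpr int ss_scale_d_operand(int n) { return n / 2 + 2; }
constexpr int rs_scale_d_operand(int n) { return n / 2 + 5; }

// These match the "setp.ne.b32 p, %k, 0;" lines of the corresponding atoms.
static_assert(ss_scale_d_operand(176) ==  90 && rs_scale_d_operand(176) ==  93, "");
static_assert(ss_scale_d_operand(208) == 106 && rs_scale_d_operand(208) == 109, "");
static_assert(ss_scale_d_operand(240) == 122 && rs_scale_d_operand(240) == 125, "");
```
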
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %110, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " p, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + 
"+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %113, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, 
%7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " p, %114, %115;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & 
d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
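
These `fma` overloads are not intended to be called by hand (the 64x224 SS atom above alone takes 112 accumulator references); in practice they are wrapped by CuTe's `MMA_Atom`/`TiledMMA` machinery, which owns the accumulator fragment and expands the argument pack. A rough usage sketch, assuming these atoms live in `cute::SM90::GMMA` and have the usual `MMA_Traits` specializations (both assumptions based on surrounding CuTe SM90 support, not shown in this patch):

```cpp
// Rough sketch only: header paths, the SM90::GMMA namespace, and the helper name
// below are assumptions; the atom name comes from the definitions above.
#include <cute/atom/mma_atom.hpp>
#include <cute/atom/mma_traits_sm90_gmma.hpp>

using namespace cute;

// One of the TF32 warpgroup atoms defined above; scaleA/scaleB default to One.
using Op = SM90::GMMA::MMA_64x224x8_F32TF32TF32_SS_TN<>;

// make_tiled_mma builds a TiledMMA from the atom's register descriptions; CuTe's
// SM90 mainloops then drive scale_D through the traits' accumulate_ state so the
// first wgmma of a tile can overwrite the accumulator (ScaleOut::Zero) and the
// remaining ones accumulate (ScaleOut::One).
inline auto make_tf32_warpgroup_tiled_mma() {   // hypothetical helper name
  return make_tiled_mma(Op{});
}
```

The SS form keeps both operands in shared memory behind descriptors, while the RS form streams the A fragment from registers, which is why only `desc_b` remains a `uint64_t` descriptor in the RS signatures.
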
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p, %118, %119;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), 
"+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %118, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " p, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + 
float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %121, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " p, %122, %123;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), 
"+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, 
%71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, 
float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p, %126, %127;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); 
+#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x8_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %126, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " p, %127, %128;\n" + "}\n" + : "+f"(d000), 
"+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x8_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x8 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x8_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, 
float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %129, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k8.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " p, %130, %131;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x8_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
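+//
+// The wgmma wrappers in this file all follow the same operand convention: the accumulator
+// registers are read-modify-write ("+f"/"+r"), desc_a/desc_b are 64-bit shared-memory matrix
+// descriptors ("l"), and scale_D is lowered onto the predicate p so that ScaleOut::Zero makes
+// the instruction overwrite the accumulator (D = A*B) while ScaleOut::One accumulates
+// (D = D + A*B). The floating-point atoms above additionally take scaleA/scaleB as immediate
+// ("n") operands that can negate the A/B inputs (ScaleIn::Neg); the integer atoms that follow
+// omit them. A minimal usage sketch, assuming the CuTe TiledMMA API, with namespace qualifiers
+// omitted and the partitioned fragments tCrA/tCrB/tCrC left undefined here:
+//
+//   auto tiled_mma = cute::make_tiled_mma(MMA_64x240x8_F32TF32TF32_SS_TN<>{});
+//   cute::warpgroup_arrive();                 // fence register/smem inputs before wgmma
+//   cute::gemm(tiled_mma, tCrA, tCrB, tCrC);  // issues wgmma.mma_async through this atom
+//   cute::warpgroup_commit_batch();
+//   cute::warpgroup_wait<0>();                // accumulator results in tCrC are now visible
+//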
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*S8 +struct MMA_64x24x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*S8 +struct MMA_64x24x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*S8 +struct MMA_64x48x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*S8 +struct MMA_64x48x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=S8*S8 +struct MMA_64x80x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + 
uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=S8*S8 +struct MMA_64x80x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x80x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=S8*S8 +struct MMA_64x112x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=S8*S8 +struct MMA_64x112x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & 
d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=S8*S8 +struct MMA_64x144x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t 
& d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=S8*S8 +struct MMA_64x144x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, 
uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=S8*S8 +struct MMA_64x160x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut 
const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=S8*S8 +struct MMA_64x160x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + 
uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=S8*S8 +struct MMA_64x176x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & 
d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=S8*S8 +struct MMA_64x176x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & 
d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=S8*S8 +struct MMA_64x208x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, 
uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + 
"+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=S8*S8 +struct MMA_64x208x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), 
"+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*S8 +struct MMA_64x224x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, 
uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*S8 +struct MMA_64x224x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, 
uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), 
"+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*S8 +struct MMA_64x240x32_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " 
+ " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*S8 +struct MMA_64x240x32_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, 
+ uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), 
"+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*S8 +struct MMA_64x24x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*S8 +struct MMA_64x24x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*S8 +struct MMA_64x48x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters 
= uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*S8 +struct MMA_64x48x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// 
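+
+// Usage note: the GMMA ops in this header are not meant to be called directly. In CuTe they
+// are normally consumed through `cute::MMA_Atom` / `cute::make_tiled_mma`, which use the
+// DRegisters/ARegisters/BRegisters/CRegisters typedefs to partition fragments and then
+// dispatch to `fma(...)`. A minimal sketch (assuming the matching MMA_Traits specialization
+// is in scope; the tensor names are illustrative, not prescriptive):
+//
+//   auto tiled_mma = cute::make_tiled_mma(MMA_64x80x32_S32S8S8_RS_TN{});
+//   // partition operands with tiled_mma.get_thread_slice(threadIdx.x), then
+//   // cute::gemm(tiled_mma, tCrA, tCrB, tCrC) issues the wgmma.mma_async shown below.
+//
+// Operand numbering in the RS (register-sourced A) atoms that follow: each asm block lists
+// the accumulator registers first, then the four A-fragment registers, then desc_b, and
+// finally the scale_D input. For MMA_64x80x32_S32S8S8_RS_TN that is %0-%39 for D, %40-%43
+// for A, %44 for desc_b, and %45 for scale_D, which is why its predicate is set from %45.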
+ +// GMMA 64x80x32 TN S32+=S8*S8 +struct MMA_64x80x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=S8*S8 +struct MMA_64x80x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm 
volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=S8*S8 +struct MMA_64x112x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + 
"+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=S8*S8 +struct MMA_64x112x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=S8*S8 +struct 
MMA_64x144x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=S8*S8 +struct MMA_64x144x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using 
BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=S8*S8 +struct MMA_64x160x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static 
void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=S8*S8 +struct 
MMA_64x160x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8S8_RS_TN_SATURATE without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=S8*S8 +struct MMA_64x176x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), 
"+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=S8*S8 +struct MMA_64x176x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), 
"+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=S8*S8 +struct MMA_64x208x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, 
%5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=S8*S8 +struct MMA_64x208x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, 
uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*S8 +struct 
MMA_64x224x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), 
"+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*S8 +struct MMA_64x224x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, 
uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*S8 +struct MMA_64x240x32_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t 
& d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), 
"+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*S8 +struct MMA_64x240x32_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & 
d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*U8 +struct MMA_64x24x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*U8 +struct MMA_64x24x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*U8 +struct MMA_64x48x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + 
"+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*U8 +struct MMA_64x48x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=S8*U8 +struct MMA_64x80x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, 
%37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=S8*U8 +struct MMA_64x80x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=S8*U8 +struct MMA_64x112x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, 
uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=S8*U8 +struct MMA_64x112x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, 
uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=S8*U8 +struct MMA_64x144x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, 
%21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=S8*U8 +struct MMA_64x144x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, 
%53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=S8*U8 +struct MMA_64x160x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, 
" + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=S8*U8 +struct MMA_64x160x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " 
%48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=S8*U8 +struct MMA_64x176x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, 
%12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=S8*U8 +struct MMA_64x176x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=S8*U8 +struct MMA_64x208x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & 
d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=S8*U8 +struct MMA_64x208x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& 
desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), 
"+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*U8 +struct MMA_64x224x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, 
%31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*U8 +struct MMA_64x224x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, 
uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
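+// NOTE (illustrative sketch, not part of the generated atoms): each *_SS_TN atom above takes two
+// 64-bit shared-memory matrix descriptors and accumulates into the per-thread uint32_t fragment
+// passed by reference; the *_RS_TN atoms below source the A operand from registers instead. The
+// _SATURATE variants emit the `.satfinite` PTX form, which clamps the s32 accumulator rather than
+// wrapping on overflow. A minimal direct call, assuming placeholder names `a0..a3` (A fragment),
+// `desc_b` (B descriptor) and `d[12]` (accumulators) are already populated, might look like:
+//
+//   uint32_t d[12] = {};  // one atom's worth of C/D accumulator registers
+//   MMA_64x24x32_S32S8U8_RS_TN::fma(a0, a1, a2, a3, desc_b,
+//       d[0], d[1], d[2], d[3], d[4],  d[5],
+//       d[6], d[7], d[8], d[9], d[10], d[11],
+//       GMMA::ScaleOut::One);   // One: D += A*B (predicate set); Zero: D = A*B
+//
+// In practice these atoms are issued through MMA_Atom/TiledMMA and bracketed by the warpgroup
+// arrive/commit/wait fences used elsewhere in CUTLASS rather than being called directly.
+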
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*U8 +struct MMA_64x240x32_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), 
"+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*U8 +struct MMA_64x240x32_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & 
d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*U8 +struct 
MMA_64x24x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=S8*U8 +struct MMA_64x24x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*U8 +struct MMA_64x48x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & 
d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=S8*U8 +struct MMA_64x48x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=S8*U8 +struct MMA_64x80x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t 
& d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=S8*U8 +struct MMA_64x80x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), 
"+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=S8*U8 +struct MMA_64x112x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=S8*U8 +struct 
MMA_64x112x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=S8*U8 +struct MMA_64x144x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & 
d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=S8*U8 +struct MMA_64x144x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & 
d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=S8*U8 +struct MMA_64x160x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, 
uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=S8*U8 +struct MMA_64x160x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & 
d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=S8*U8 +struct MMA_64x176x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & 
d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=S8*U8 +struct 
MMA_64x176x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + 
"+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=S8*U8 +struct MMA_64x208x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p;\n" + "}\n" + : 
"+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=S8*U8 +struct MMA_64x208x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, 
uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*U8 +struct MMA_64x224x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t 
& d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), 
"+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=S8*U8 +struct MMA_64x224x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, 
%23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*U8 +struct MMA_64x240x32_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & 
d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), 
"+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=S8*U8 +struct MMA_64x240x32_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " 
+ " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*S8 +struct MMA_64x24x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32U8S8_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*S8 +struct MMA_64x24x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*S8 +struct MMA_64x48x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*S8 +struct MMA_64x48x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, 
uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*S8 +struct MMA_64x80x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8S8_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*S8 +struct MMA_64x80x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*S8 +struct MMA_64x112x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, 
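+        // Each of the warpgroup's 128 threads holds 56 of the 64x112 s32 accumulators (64*112/128 = 56), which is why CRegisters is uint32_t[56].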
uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*S8 +struct MMA_64x112x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, 
%55}," + " %56," + " %57," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*S8 +struct MMA_64x144x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + 
"+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*S8 +struct MMA_64x144x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + 
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*S8 +struct MMA_64x160x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), 
"+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*S8 +struct MMA_64x160x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), 
"+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=U8*S8 +struct MMA_64x176x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), 
"+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=U8*S8 +struct MMA_64x176x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, 
%63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=U8*S8 +struct MMA_64x208x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + 
uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=U8*S8 +struct MMA_64x208x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, 
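+        // Like all _SS_ atoms, both the A and B tiles are read from shared memory via the 64-bit matrix descriptors desc_a and desc_b.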
uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x208x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=U8*S8 +struct MMA_64x224x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), 
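+        // synclog_emit_wgmma_smem_smem records this WGMMA issue in CUTLASS's synclog trace when that debugging facility is enabled; otherwise it has no effect.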
"+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=U8*S8 +struct MMA_64x224x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, 
uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*S8 +struct MMA_64x240x32_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & 
d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), 
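+        // wgmma.mma_async is issued collectively by the 128 threads of the warpgroup and completes asynchronously; the accumulators are only safe to read after a wgmma.commit_group / wgmma.wait_group fence.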
"+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*S8 +struct MMA_64x240x32_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut 
const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*S8 +struct MMA_64x24x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, 
desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*S8 +struct MMA_64x24x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*S8 +struct MMA_64x48x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), 
"+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*S8 +struct MMA_64x48x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*S8 +struct MMA_64x80x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*S8 +struct MMA_64x80x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*S8 +struct 
MMA_64x112x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*S8 +struct MMA_64x112x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, 
+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*S8 +struct MMA_64x144x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, 
uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*S8 +struct MMA_64x144x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + 
uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*S8 +struct MMA_64x160x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, 
uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*S8 +struct MMA_64x160x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & 
d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=U8*S8 +struct MMA_64x176x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & 
d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=U8*S8 +struct MMA_64x176x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & 
d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=U8*S8 +struct MMA_64x208x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + 
fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), 
"+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=U8*S8 +struct MMA_64x208x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, 
%18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=U8*S8 +struct MMA_64x224x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, 
uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8S8_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=U8*S8 +struct MMA_64x224x32_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), 
"+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*S8 +struct MMA_64x240x32_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, 
uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*S8 +struct MMA_64x240x32_S32U8S8_RS_TN_SATURATE +{ + using 
DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), 
"+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*U8 +struct MMA_64x24x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*U8 +struct MMA_64x24x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*U8 +struct MMA_64x48x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*U8 +struct MMA_64x48x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*U8 +struct MMA_64x80x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*U8 +struct MMA_64x80x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t 
& d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*U8 +struct MMA_64x112x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, 
%39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*U8 +struct MMA_64x112x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), 
"+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*U8 +struct MMA_64x144x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x144x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*U8 +struct MMA_64x144x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
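+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// Annotation (descriptive comment only, not part of the generated atoms): every struct in this
+// family is a thin wrapper around a single Hopper `wgmma.mma_async` PTX instruction, and the
+// name encodes the variant, roughly as follows:
+//   MMA_64x<N>x32_S32U8U8_SS_TN[_SATURATE]
+//     64x<N>x32  - the MxNxK tile computed by one warpgroup per call (M = 64, K = 32)
+//     S32U8U8    - signed 32-bit accumulator with unsigned 8-bit A and B operands
+//     SS / RS    - A and B both sourced from shared memory via 64-bit descriptors (SS), or
+//                  A sourced from registers a000..a003 and B from a descriptor (RS)
+//     TN         - K-major A and K-major B (the layout used by the integer wgmma variants here)
+//     _SATURATE  - lowers to the `.satfinite` qualifier, clamping the s32 accumulator on overflow
+// The `scale_D` argument is materialized as predicate `p` (`setp.ne.b32`) and passed as the
+// scale-d operand, so ScaleOut::One accumulates into D while ScaleOut::Zero overwrites it.
+// `CRegisters` holds N/2 32-bit fragments per thread: the 64*N accumulators of the tile are
+// distributed across the 128 threads of the warpgroup.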
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*U8 +struct MMA_64x160x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x160x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*U8 +struct MMA_64x160x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + 
"l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=U8*U8 +struct MMA_64x176x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), 
"+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=U8*U8 +struct MMA_64x176x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), 
+ "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=U8*U8 +struct MMA_64x208x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, 
%26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=U8*U8 +struct MMA_64x208x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, 
uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=U8*U8 +struct MMA_64x224x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, 
uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), 
"+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=U8*U8 +struct MMA_64x224x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 
p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*U8 +struct MMA_64x240x32_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & 
d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), 
"+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*U8 +struct MMA_64x240x32_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, 
%41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*U8 +struct MMA_64x24x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
MMA_64x24x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN S32+=U8*U8 +struct MMA_64x24x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*U8 +struct MMA_64x48x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN S32+=U8*U8 +struct MMA_64x48x32_S32U8U8_RS_TN_SATURATE +{ + using 
DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*U8 +struct MMA_64x80x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), 
"+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN S32+=U8*U8 +struct MMA_64x80x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*U8 +struct MMA_64x112x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, 
uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN S32+=U8*U8 +struct MMA_64x112x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & 
d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*U8 +struct MMA_64x144x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.u8 " + "{%0, %1, %2, %3, 
%4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN S32+=U8*U8 +struct MMA_64x144x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, 
%20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*U8 +struct MMA_64x160x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, 
%14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN S32+=U8*U8 +struct MMA_64x160x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + 
".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=U8*U8 +struct MMA_64x176x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + 
uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN S32+=U8*U8 +struct MMA_64x176x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, 
uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=U8*U8 +struct MMA_64x208x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, 
+ uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + 
"+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN S32+=U8*U8 +struct MMA_64x208x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " 
p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=U8*U8 +struct MMA_64x224x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + 
uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN S32+=U8*U8 +struct MMA_64x224x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, 
uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), 
"+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*U8 +struct MMA_64x240x32_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, 
uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN S32+=U8*U8 +struct MMA_64x240x32_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, 
uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + 
"+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + 
using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, 
desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n48k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + 
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), 
+ "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + 
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n72k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F16+=E4M3*E4M3 +template < 
+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
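Every atom added in this hunk follows the same call pattern; below is a minimal sketch (not part of the diff) of driving one of them directly, assuming the atom's enclosing cute namespace is already in scope, that desc_a/desc_b are shared-memory matrix descriptors built elsewhere, and that the usual CuTe warpgroup fence helpers (warpgroup_arrive, warpgroup_commit_batch, warpgroup_wait) are available — none of that scaffolding appears in this hunk.

// Illustrative sketch only: issue a single 64x80x32 FP8 (E4M3 x E4M3) GMMA with
// F16 accumulation, accumulating into 20 packed 16-bit-pair registers.
// Descriptor construction and the surrounding mainloop are assumed, not shown.
__device__ void issue_one_gmma(uint64_t desc_a, uint64_t desc_b, uint32_t (&acc)[20])
{
  cute::warpgroup_arrive();                       // fence prior register/SMEM writes (assumed CuTe helper)
  MMA_64x80x32_F16E4M3E4M3_SS_TN<>::fma(          // atom from this hunk; enclosing namespace assumed in scope
      desc_a, desc_b,
      acc[0],  acc[1],  acc[2],  acc[3],  acc[4],  acc[5],  acc[6],  acc[7],
      acc[8],  acc[9],  acc[10], acc[11], acc[12], acc[13], acc[14], acc[15],
      acc[16], acc[17], acc[18], acc[19],
      GMMA::ScaleOut::One);                       // keep prior accumulator contents: D = A*B + D
  cute::warpgroup_commit_batch();                 // group the issued wgmma (assumed CuTe helper)
  cute::warpgroup_wait<0>();                      // block until the committed GMMA retires
}

The RS variants differ only in taking four packed A-fragment registers in place of desc_a, and the F32 variants only in the width and count of the accumulator arguments; otherwise the calling sequence is identical.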
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t 
const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), 
"+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, 
uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, 
float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : 
"l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & 
d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, 
%53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE 
static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, 
%31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + 
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, 
uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %70, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, 
%27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " p, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %73, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " p, %74, %75;\n" + "}\n" + : "+f"(d00), 
"+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, 
float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) 
+ { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p, %78, %79;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), 
"+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + 
CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %78, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " p, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + 
uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %81, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " p, %82, %83;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & 
d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, 
%28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), 
"+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p, %86, %87;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + 
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + 
float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %86, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " p, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float 
& d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %89, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " p, %90, %91;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, 
uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n176k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " 
%72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, 
%21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p, %94, %95;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, 
%45}," + " %46," + " %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E4M3E4M3_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %94, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " p, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + 
"+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %97, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " p, %98, %99;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), 
"+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, 
float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %102, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " p, %103, %104;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + 
"+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %105, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " p, %106, %107;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + 
"+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), 
"+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), 
"+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " 
%40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p, %110, %111;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " p, %60, %61;\n" + "}\n" + : 
"+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %110, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, 
%51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " p, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & 
d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %113, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " p, %114, %115;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters 
= uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, 
uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, 
float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p, %118, %119;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), 
"+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), 
"+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : 
"r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %118, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " p, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), 
"+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float 
& d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %121, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " p, %122, %123;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + 
fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, 
uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, 
float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), 
"+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, 
%115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p, %126, %127;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, 
%20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, 
float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %126, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " p, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + 
float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %129, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " p, %130, %131;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), 
"+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), 
+ "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using 
ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" 
+ "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), 
"+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), 
"+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), 
"+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 
0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F16+=E4M3*E5M2 +template < 
+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
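
All of the atoms in this file follow the same CuTe MMA Operation pattern: `ARegisters`/`BRegisters` name the descriptor or register fragments `fma` consumes (a single `uint64_t` shared-memory descriptor per operand for the `_SS_TN` forms, four `uint32_t` A fragments plus a B descriptor for `_RS_TN`), `CRegisters` names the accumulator fragment, and `fma` issues exactly one `wgmma.mma_async` instruction. The runtime `scale_D` argument becomes the instruction's scale-d predicate (`GMMA::ScaleOut::One` accumulates into D, while a zero scale overwrites it), and the `scaleA`/`scaleB` template parameters feed the immediate operands that can negate the A and B inputs. As a rough sketch of how one of these FP8 atoms is consumed higher up the stack — the include paths, the namespace nesting, and the `make_tiled_mma` call below are assumptions for illustration, not something this diff establishes:

```cpp
// Illustrative sketch only; not part of this change. Assumes the usual CuTe
// headers and that the atom below is reachable from namespace cute (adjust the
// qualification to whatever namespace nesting this header actually uses).
#include <cute/tensor.hpp>
#include <cute/atom/mma_atom.hpp>

using namespace cute;

// D(f32) += A(e4m3) * B(e5m2) on a 64x80x32 tile with both operands sourced
// from shared memory (the _SS_TN variant); ScaleIn::One leaves A and B
// un-negated.
using Fp8Op = MMA_64x80x32_F32E4M3E5M2_SS_TN<GMMA::ScaleIn::One,
                                             GMMA::ScaleIn::One>;

// The MMA_Traits specialization provided alongside these atoms supplies the
// thread/value layouts, so the op can be lifted into a warpgroup-wide TiledMMA.
using TiledMma = decltype(make_tiled_mma(Fp8Op{}));
```

At the collective-mainloop level the kernel still has to bracket calls to such an atom with the usual `wgmma.fence` / `wgmma.commit_group` / `wgmma.wait_group` sequence; the structs here encapsulate only the single MMA instruction itself.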
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t 
const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), 
"+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, 
uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, 
float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : 
"l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & 
d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, 
%53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE 
static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, 
%31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + 
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, 
uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %70, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, 
%27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " p, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %73, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " p, %74, %75;\n" + "}\n" + : "+f"(d00), 
"+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, 
float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) 
+ { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p, %78, %79;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), 
"+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + 
CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %78, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " p, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + 
uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %81, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " p, %82, %83;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & 
d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, 
%28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), 
"+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p, %86, %87;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + 
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + 
float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %86, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " p, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float 
& d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %89, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " p, %90, %91;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, 
uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n176k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " 
%72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, 
%21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p, %94, %95;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, 
%45}," + " %46," + " %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E4M3E5M2_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %94, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " p, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + 
"+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %97, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " p, %98, %99;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), 
"+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, 
float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %102, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " p, %103, %104;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + 
"+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %105, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " p, %106, %107;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + 
"+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), 
"+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), 
"+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " 
%40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p, %110, %111;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " p, %60, %61;\n" + "}\n" + : 
"+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %110, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, 
%51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " p, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & 
d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %113, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " p, %114, %115;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters 
= uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, 
uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, 
float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p, %118, %119;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), 
"+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), 
"+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : 
"r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %118, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " p, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), 
"+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float 
& d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %121, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " p, %122, %123;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + 
fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, 
uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, 
float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), 
"+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, 
%115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p, %126, %127;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, 
%20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, 
float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %126, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " p, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + 
float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %129, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " p, %130, %131;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), 
"+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), 
+ "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using 
ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" 
+ "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), 
"+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), 
"+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), 
"+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 
0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F16+=E5M2*E4M3 +template < 
+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
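[Reviewer note, not part of the patch] These hunks are mechanically generated wrappers around the SM90 `wgmma.mma_async` PTX instruction for FP8 `E5M2 x E4M3` inputs, one struct per N tile size; the `_SS_` variants read both operands through shared-memory matrix descriptors, while the `_RS_` variants take the A fragment from registers. As a rough illustration of how such an atom is invoked, below is a minimal sketch (purely illustrative, not part of the diff) that calls the 64x80x32 F16-accumulating SS variant defined just above. It assumes `desc_a` and `desc_b` are valid GMMA shared-memory descriptors built elsewhere, and it omits the warpgroup fence/commit/wait synchronization that real callers issue around `wgmma.mma_async`.

// Illustrative sketch only, under the assumptions stated above.
__device__ void fp8_wgmma_64x80_sketch(uint64_t desc_a, uint64_t desc_b)
{
  uint32_t d[20] = {};  // 20 packed f16x2 accumulators (CRegisters = uint32_t[20])

  // D = A * B + D; ScaleOut::One keeps the existing accumulator contents.
  MMA_64x80x32_F16E5M2E4M3_SS_TN<>::fma(
      desc_a, desc_b,
      d[ 0], d[ 1], d[ 2], d[ 3], d[ 4], d[ 5], d[ 6], d[ 7],
      d[ 8], d[ 9], d[10], d[11], d[12], d[13], d[14], d[15],
      d[16], d[17], d[18], d[19],
      GMMA::ScaleOut::One);
}

In practice these atoms are not called by hand; they are consumed through CuTe's tiled-MMA machinery, which also handles descriptor construction and warpgroup synchronization.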
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t 
const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), 
"+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, 
uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, 
float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : 
"l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & 
d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, 
%53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE 
static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, 
%31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + 
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, 
uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %70, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, 
%27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " p, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %73, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " p, %74, %75;\n" + "}\n" + : "+f"(d00), 
"+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, 
float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) 
+ { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p, %78, %79;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), 
"+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + 
CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %78, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " p, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + 
uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %81, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " p, %82, %83;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & 
d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, 
%28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), 
"+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p, %86, %87;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + 
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + 
float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %86, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " p, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float 
& d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %89, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " p, %90, %91;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, 
uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n176k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " 
%72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, 
%21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p, %94, %95;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, 
%45}," + " %46," + " %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E5M2E4M3_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %94, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " p, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + 
"+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %97, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " p, %98, %99;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), 
"+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, 
float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %102, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " p, %103, %104;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + 
"+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %105, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " p, %106, %107;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + 
"+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), 
"+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), 
"+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " 
%40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p, %110, %111;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " p, %60, %61;\n" + "}\n" + : 
"+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %110, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, 
%51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " p, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & 
d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %113, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " p, %114, %115;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters 
= uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, 
uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, 
float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p, %118, %119;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), 
"+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), 
"+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : 
"r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %118, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " p, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), 
"+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float 
& d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %121, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " p, %122, %123;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + 
fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, 
uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, 
float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), 
"+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, 
%115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p, %126, %127;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, 
%20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, 
float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %126, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " p, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + 
float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %129, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " p, %130, %131;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), 
"+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), 
+ "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x24x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x24x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n24k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x24x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using 
ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x40x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x40x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n40k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x40x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" 
+ "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " p, %18, %19;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " p, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), 
"+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x48x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x48x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sync.aligned.m64n48k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " p, %30, %31;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x48x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), 
"+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), 
"+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x56x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x56x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n56k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x56x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 
0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x72x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x72x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n72k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x72x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F16+=E5M2*E5M2 +template < 
+ GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x80x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x80x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n80k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x80x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t 
const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), 
"+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x88x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x88x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sync.aligned.m64n88k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x88x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, 
uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x104x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x104x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, 
float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n104k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x104x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : 
"l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " p, %34, %35;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & 
d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x112x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x112x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n112k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, 
%53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x112x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE 
static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, 
%31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x120x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x120x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n120k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + 
"+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x120x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, 
uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %70, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, 
%27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " p, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x136x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x136x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %73, 0;\n" + "wgmma.mma_async.sync.aligned.m64n136k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " p, %74, %75;\n" + "}\n" + : "+f"(d00), 
"+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x136x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, 
float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %74, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " p, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x144x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x144x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) 
+ { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %77, 0;\n" + "wgmma.mma_async.sync.aligned.m64n144k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " p, %78, %79;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x144x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), 
"+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + 
CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %78, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " p, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x152x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x152x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + 
uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %81, 0;\n" + "wgmma.mma_async.sync.aligned.m64n152k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " p, %82, %83;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x152x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & 
d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " p, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, 
%28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " p, %46, %47;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %82, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " p, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), 
"+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x160x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x160x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %85, 0;\n" + "wgmma.mma_async.sync.aligned.m64n160k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " p, %86, %87;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + 
"+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x160x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + 
float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %86, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " p, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x168x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x168x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float 
& d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %89, 0;\n" + "wgmma.mma_async.sync.aligned.m64n168k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " p, %90, %91;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x168x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, 
uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + 
"wgmma.mma_async.sync.aligned.m64n176k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %90, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " 
%72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " p, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x176x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x176x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %93, 0;\n" + "wgmma.mma_async.sync.aligned.m64n176k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, 
%21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " p, %94, %95;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x176x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, 
%45}," + " %46," + " %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F16E5M2E5M2_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %94, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " p, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + 
"+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x184x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x184x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %97, 0;\n" + "wgmma.mma_async.sync.aligned.m64n184k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " p, %98, %99;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), 
"+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x184x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, 
float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %102, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " p, %103, %104;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + 
"+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x200x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x200x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %105, 0;\n" + "wgmma.mma_async.sync.aligned.m64n200k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " p, %106, %107;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + 
"+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x200x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), 
"+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %106, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " p, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), 
"+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x208x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x208x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %109, 0;\n" + "wgmma.mma_async.sync.aligned.m64n208k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " 
%40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " p, %110, %111;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x208x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " p, %60, %61;\n" + "}\n" + : 
"+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %110, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, 
%51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " p, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x216x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x216x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & 
d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %113, 0;\n" + "wgmma.mma_async.sync.aligned.m64n216k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " p, %114, %115;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x216x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters 
= uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " p, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, 
uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " p, %62, %63;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, 
float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %114, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " p, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x224x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x224x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %117, 0;\n" + "wgmma.mma_async.sync.aligned.m64n224k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " p, %118, %119;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), 
"+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x224x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), 
"+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : 
"r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %118, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " p, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), 
"+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x232x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x232x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float 
& d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %121, 0;\n" + "wgmma.mma_async.sync.aligned.m64n232k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " p, %122, %123;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x232x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + 
fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, 
uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, 
float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %122, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " p, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), 
"+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x240x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x240x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %125, 0;\n" + "wgmma.mma_async.sync.aligned.m64n240k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, 
%115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " p, %126, %127;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x240x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, 
%20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, 
float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %126, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " p, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA 64x248x32 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +struct MMA_64x248x32_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + 
float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %129, 0;\n" + "wgmma.mma_async.sync.aligned.m64n248k32.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " p, %130, %131;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), 
"+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use MMA_64x248x32_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SM90::GMMA + +} // namespace cute diff --git a/include/cute/arch/mma_sm90_gmma_sparse.hpp b/include/cute/arch/mma_sm90_gmma_sparse.hpp new file mode 100644 index 0000000000..ecca91b93c --- /dev/null +++ b/include/cute/arch/mma_sm90_gmma_sparse.hpp @@ -0,0 +1,22743 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include // CUTE_HOST_DEVICE +#include // GMMA::Major, etc. 
+ +namespace cute { + +namespace SM90::GMMA::SPARSE { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// GMMA PTX definitions: C = (scaleA * A) * (scaleB * B) + (scaleD * C) +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f16.f16.f16 " + "{%0, %1}," + " %2," + " %3," + " %4, %5," + " p, %7, %8, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f16.f16.f16 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " %7, %8," + " p, %10, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero +> +struct GMMA_64x16x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f16.f16.f16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f16.f16.f16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t 
& d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14, %15, %16;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17, %18;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut 
const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), 
"+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, 
uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg 
.pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, 
%56, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70, %71, %72;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + 
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73, %74;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), 
"+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f32.f16.f16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, 
%11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f32.f16.f16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14, %15, %16;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17, %18;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), 
"r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," 
+ "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must 
have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, 
%13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), 
"+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), 
"+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), 
"r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102, %103, %104;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + 
"+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + 
" %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105, %106;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + 
float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134, %135, %136;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), 
"+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, 
%22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137, %138;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm 
volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f32.bf16.bf16 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10, %11, %12;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k32.f32.bf16.bf16 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14, %15, %16;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17, %18;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22, %23, %24;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + 
"l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38, %39, %40;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to 
use SM90::GMMA::SPARSE::GMMA_64x64x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54, %55, %56;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A 
must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & 
d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70, %71, %72;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + 
float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + 
float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102, %103, %104;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + 
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105, %106;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), 
"n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, 
%82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134, %135, %136;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, 
float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137, %138;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), 
"+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k16.f32.tf32.tf32 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k16.f32.tf32.tf32 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + 
"l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE 
GMMA 64x64x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k16.f32.tf32.tf32 " + "{%0, 
%1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
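// Fallback when CUTE_ARCH_MMA_SM90A_ENABLED is not defined (i.e. the translation unit is not compiled for sm_90a): using this atom is reported as an invalid control path. +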
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t 
const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & 
d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + 
float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + 
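// Per-thread register model for this RS atom: A is fed from registers (ARegisters), the structured-sparsity metadata from the `e` operand (ERegisters), and B through a 64-bit shared-memory matrix descriptor (BRegisters); the accumulators live in CRegisters and D is updated in place, hence DRegisters = void. + // In practice these atoms are driven through cute::MMA_Atom / MMA_Traits rather than by calling fma() directly. +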
using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + 
"+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + 
" %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, 
+ float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), 
"+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.s8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x16x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); 
+ asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters 
= uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & 
d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, 
%17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, 
%62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + 
" %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t 
& d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + 
uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), 
"+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, 
uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.s8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm 
volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), 
"r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, 
%22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& 
a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & 
d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, 
%20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + 
uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = 
uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," 
+ " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & 
d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), 
"+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.u8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = 
uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, 
%13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.u8 " + "{%0, %1, %2, %3, 
%4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t 
& d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, 
%42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), 
"+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + 
" %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & 
d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, 
+ uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), 
"+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, 
uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*U8 +template < + 
GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.u8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + 
"{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), 
"+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & 
d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, 
uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, 
%20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " 
%48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & 
d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + 
uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), 
"+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & 
d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), 
"+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.s8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, 
uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), 
"+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + 
" %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, 
uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm 
volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" 
+ : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), 
"+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + 
" %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, 
uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & 
d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + 
"+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, 
uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters 
= uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.s8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + 
"l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), 
"+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, 
uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, 
%38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, 
%66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.s8 " + "{%0, %1, %2, 
%3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & 
d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + 
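+      // d000..d127 form the per-thread S32 accumulator fragment: the 64x256 accumulator
+      // tile holds 64*256 = 16384 s32 values spread across the 128 threads of the issuing
+      // warpgroup, i.e. 128 registers per thread (matching CRegisters = uint32_t[128]).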
uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), 
"+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + 
uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), 
"+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.u8 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + 
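+        // desc_a / desc_b are 64-bit shared-memory matrix descriptors ("l" constraint).
+        // In these _SS_TN atoms both A and B are read from shared memory via descriptors,
+        // whereas the _RS_TN atoms take the A fragment directly in registers.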
"l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), 
"+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & 
d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + 
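+        // The setp above converts the runtime scale_D argument (operand %52) into
+        // predicate p, which is passed as the instruction's scale-D input: when p is 0
+        // the previous accumulator is ignored (D = A*B), otherwise D = A*B + D.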
"wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), 
"+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), 
"+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, 
%81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, 
uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t 
& d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + 
"+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & 
d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.u8 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + 
"r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN S32+=U8*U8 
+template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), 
+ "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, 
uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, 
%63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), 
"+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, 
%28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, 
uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t & d88, uint32_t & d89, uint32_t & d90, uint32_t & d91, + uint32_t & d92, uint32_t & d93, uint32_t & d94, uint32_t & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87), + "+r"(d88), "+r"(d89), "+r"(d90), "+r"(d91), + "+r"(d92), "+r"(d93), "+r"(d94), "+r"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + 
uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), 
"+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + 
uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t & d120, uint32_t & d121, uint32_t & d122, uint32_t & d123, + uint32_t & d124, uint32_t & d125, uint32_t & d126, uint32_t & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119), + "+r"(d120), "+r"(d121), "+r"(d122), "+r"(d123), + "+r"(d124), "+r"(d125), "+r"(d126), "+r"(d127) + : "r"(a000), "r"(a001), 
"r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e4m3.e4m3 " + "{%0, %1}," + " %2," + " %3," + " %4, %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e4m3.e4m3 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " %7, %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE 
static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using 
ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const 
scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, 
uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + 
GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " 
%24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, 
float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, 
%55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, 
float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & 
d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + 
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, 
%89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + 
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t 
& d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float 
& d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), 
"+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, 
%29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e4m3.e5m2 " + "{%0, %1}," + " %2," + " %3," + " %4, %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e4m3.e5m2 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " %7, %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> 
+struct GMMA_64x8x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E4M3E5M2_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = 
void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero 
+> +struct GMMA_64x64x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 
TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, 
%5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + 
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+f"(d00), 
"+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, 
uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, 
%63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), 
"+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x192x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), 
"+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + 
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + 
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + 
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& 
a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + 
"+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e5m2.e4m3 " + "{%0, %1}," + " %2," + " %3," + " %4, %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e5m2.e4m3 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " %7, %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
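+// ---------------------------------------------------------------------------
+// Usage sketch (illustrative only): these sparse atoms are normally driven
+// through CuTe's MMA_Atom/MMA_Traits machinery, but each one can also be
+// issued directly by calling its static fma() between the warpgroup fences
+// assumed to be provided by CuTe's SM90 GMMA helpers (warpgroup_arrive,
+// warpgroup_commit_batch, warpgroup_wait in cute/arch/mma_sm90_gmma.hpp).
+// The descriptor and metadata values below are placeholders; real code builds
+// desc_a/desc_b as shared-memory GMMA descriptors and takes the metadata word
+// `e` from the compressed sparse-E fragment.
+//
+// #if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+// __device__ void sparse_gmma_64x8x64_example(uint64_t desc_a, uint64_t desc_b, uint32_t e)
+// {
+//   // One thread's share of the 64x8 FP32 accumulator for this atom.
+//   float d0 = 0.f, d1 = 0.f, d2 = 0.f, d3 = 0.f;
+//   cute::warpgroup_arrive();                        // wgmma.fence before issuing the MMA
+//   cute::SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E4M3_SS_TN<>::fma(
+//       desc_a, desc_b, d0, d1, d2, d3, e, cute::GMMA::ScaleOut::One);
+//   cute::warpgroup_commit_batch();                  // wgmma.commit_group
+//   cute::warpgroup_wait<0>();                       // wait for the committed batch to retire
+// }
+// #endif
+// ---------------------------------------------------------------------------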
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float 
& d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, 
%14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), 
"+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), 
"+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x96x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t 
& d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), 
+ "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
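These sparse GMMA atoms are normally driven by CuTe's collective mainloops rather than called by hand, but a direct call makes the shared calling convention visible: every `*_SS_TN` atom above exposes `fma(desc_a, desc_b, d0, ..., dN-1, e, scale_D)`. The sketch below is illustrative only; it assumes this header is in scope, that the SMEM matrix descriptors and the 2:4 sparsity metadata word have already been built by the surrounding kernel, that the structs resolve as `cute::SM90::GMMA::SPARSE::...`, and the helper/function names are ours.

```cuda
#include <cstdint>
#include <utility>   // std::index_sequence

// Minimal sketch: expand a per-thread accumulator fragment into the long
// fma() argument list shared by the *_SS_TN atoms in this file.
template <class Atom, class AccType, std::size_t N, std::size_t... I>
__device__ void
call_sparse_ss_fma(uint64_t desc_a, uint64_t desc_b,
                   AccType (&acc)[N], uint32_t metadata,
                   std::index_sequence<I...>)
{
  // ScaleOut::Zero instead of One would overwrite rather than accumulate D.
  Atom::fma(desc_a, desc_b, acc[I]..., metadata, cute::GMMA::ScaleOut::One);
}

// Example: the 64x64x64 F16+=E5M2*E4M3 smem-x-smem atom holds 16 packed
// half-pair accumulators per thread (CRegisters = uint32_t[16]).
__device__ void
example_sparse_gmma_ss(uint64_t desc_a, uint64_t desc_b, uint32_t metadata)
{
  uint32_t acc[16] = {};
  call_sparse_ss_fma<cute::SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E4M3_SS_TN<>>(
      desc_a, desc_b, acc, metadata, std::make_index_sequence<16>{});
}
```

Real kernels additionally wrap these calls in the warpgroup fence / commit / wait sequence that CuTe's collective mainloops emit; that bookkeeping is omitted from the sketch.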
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + 
uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, 
float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & 
d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E4M3_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
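The `_RS_TN` variants differ from `_SS_TN` only in how A is sourced: four 32-bit register fragments (`ARegisters = uint32_t[4]`) replace the 64-bit SMEM descriptor, which shifts every later asm operand index by three. For the 64x256 F16 atom just above, the mapping is: accumulators %0-%63, A registers %64-%67, B descriptor %68, metadata %69, sparse selector %70, scale-D %71 (consumed by the `setp` that defines `p`), scale-A %72 and scale-B %73. A matching call sketch, with the same caveats and illustrative names as the SS sketch earlier:

```cuda
#include <cstdint>
#include <utility>   // std::index_sequence

// Sketch of the RS calling convention: A is supplied as four 32-bit
// register fragments in place of the SS variants' SMEM descriptor.
template <class Atom, class AccType, std::size_t N, std::size_t... I>
__device__ void
call_sparse_rs_fma(uint32_t const (&a)[4], uint64_t desc_b,
                   AccType (&acc)[N], uint32_t metadata,
                   std::index_sequence<I...>)
{
  Atom::fma(a[0], a[1], a[2], a[3], desc_b,
            acc[I]..., metadata, cute::GMMA::ScaleOut::One);
}

__device__ void
example_sparse_gmma_rs(uint32_t const (&a_frag)[4], uint64_t desc_b, uint32_t metadata)
{
  uint32_t acc[16] = {};  // 64x64x64 F16 RS atom: CRegisters = uint32_t[16]
  call_sparse_rs_fma<cute::SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E4M3_RS_TN<>>(
      a_frag, desc_b, acc, metadata, std::make_index_sequence<16>{});
}
```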
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " 
%129," + " %130, %131," + " p, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & 
d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), 
"+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %6, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e5m2.e5m2 " + "{%0, %1}," + " %2," + " %3," + " %4, %5," + " p, %7, %8;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[2]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %9, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f16.e5m2.e5m2 " + "{%0, %1}," + "{%2, %3, %4, %5}," + " %6," + " %7, %8," + " p, %10, %11;\n" + "}\n" + : "+r"(d0), "+r"(d1) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x8x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3}," + " %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x8x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x8x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n8k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %8, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3}," + 
" %4," + " %5," + " %6, %7," + " p, %9, %10;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[4]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %11, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3}," + "{%4, %5, %6, %7}," + " %8," + " %9, %10," + " p, %12, %13;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x16x64 TN F32+=E5M2*E5M2 +template < + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x16x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + float & d0, float & d1, float & d2, float & d3, + float & d4, float & d5, float & d6, float & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n16k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+f"(d0), "+f"(d1), "+f"(d2), "+f"(d3), + "+f"(d4), "+f"(d5), "+f"(d6), "+f"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x16x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %12, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + " %8," + " %9," + " %10, %11," + " p, %13, %14;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[8]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& 
a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, uint32_t & d6, uint32_t & d7, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %15, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7}," + "{%8, %9, %10, %11}," + " %12," + " %13, %14," + " p, %16, %17;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5), "+r"(d6), "+r"(d7) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x32x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x32x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, 
float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n32k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x32x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %20, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + " %16," + " %17," + " %18, %19," + " p, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[16]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + 
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %23, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15}," + "{%16, %17, %18, %19}," + " %20," + " %21, %22," + " p, %24, %25;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
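+// A minimal usage sketch for these atoms (shown here on the smallest sparse F32 atom in this
+// family), assuming `desc_a`/`desc_b` are valid GMMA shared-memory descriptors and `e` holds the
+// packed sparsity metadata for this warpgroup-wide MMA, all built elsewhere by the caller:
+//
+//   #if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+//   float acc[4] = {0.f, 0.f, 0.f, 0.f};                 // CRegisters = float[4]
+//   SM90::GMMA::SPARSE::GMMA_64x8x64_F32E5M2E5M2_SS_TN<>::fma(
+//       desc_a, desc_b,                                  // A/B shared-memory descriptors
+//       acc[0], acc[1], acc[2], acc[3],                  // accumulators updated in place
+//       e,                                               // ERegisters: sparsity metadata
+//       GMMA::ScaleOut::One);                            // One accumulates; Zero overwrites
+//   #endif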
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x64x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x64x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n64k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x64x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + 
"{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
float[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x96x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x96x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, 
+ uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n96k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x96x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %36, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + " %32," + " %33," + " %34, %35," + " p, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), 
"+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[32]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %39, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31}," + "{%32, %33, %34, %35}," + " %36," + " %37, %38," + " p, %40, %41;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float 
& d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x128x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x128x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, 
float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n128k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x128x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, 
uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %52, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + " %48," + " %49," + " %50, %51," + " p, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[48]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %55, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, 
%30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47}," + "{%48, %49, %50, %51}," + " %52," + " %53, %54," + " p, %56, %57;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %100, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " 
%48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + " %96," + " %97," + " %98, %99," + " p, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x192x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x192x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[96]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & 
d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + float & d92, float & d93, float & d94, float & d95, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %103, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n192k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95}," + "{%96, %97, %98, %99}," + " %100," + " %101, %102," + " p, %104, %105;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91), + "+f"(d92), "+f"(d93), "+f"(d94), "+f"(d95) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x192x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & 
d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %68, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + " %64," + " %65," + " %66, %67," + " p, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[64]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, 
uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %71, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63}," + "{%64, %65, %66, %67}," + " %68," + " %69, %70," + " p, %72, %73;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, 
+ float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %132, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + " %128," + " %129," + " %130, %131," + " p, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), 
"+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x256x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x256x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[128]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + float & d124, float & d125, float & d126, float & d127, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %135, 
0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n256k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123, %124, %125, %126, %127}," + "{%128, %129, %130, %131}," + " %132," + " %133, %134," + " p, %136, %137;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123), + "+f"(d124), "+f"(d125), "+f"(d126), "+f"(d127) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x256x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SM90::GMMA::SPARSE + +} // namespace cute + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +#include "mma_sm90_gmma_sparse_ext.hpp" +#endif diff --git a/include/cute/arch/mma_sm90_gmma_sparse_ext.hpp b/include/cute/arch/mma_sm90_gmma_sparse_ext.hpp new file mode 100644 index 0000000000..c224e4034e --- /dev/null +++ b/include/cute/arch/mma_sm90_gmma_sparse_ext.hpp @@ -0,0 +1,60445 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause
+ *
+ * Redistribution and use in source and binary forms, with or without
+ * modification, are permitted provided that the following conditions are met:
+ *
+ * 1. Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+#pragma once
+
+#include <cute/config.hpp>  // CUTE_HOST_DEVICE
+
+#include "cutlass/arch/synclog.hpp"
+
+// Config
+#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 900) && defined(__CUDA_ARCH_FEAT_SM90_ALL))
+#  define CUTE_ARCH_MMA_SM90A_ENABLED
+#endif
+
+namespace cute {
+
+namespace SM90::GMMA::SPARSE {
+
+// SPARSE GMMA 64x24x32 F16+=F16*F16
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One,
+  GMMA::SparseSel spsel = GMMA::SparseSel::Zero
+>
+struct GMMA_64x24x32_F16F16F16_SS
+{
+  using DRegisters = void;
+  using ARegisters = uint64_t[1];
+  using ERegisters = uint32_t[1];
+  using BRegisters = uint64_t[1];
+  using CRegisters = uint32_t[6];
+
+  CUTE_HOST_DEVICE static void
+  fma(uint64_t const& desc_a,
+      uint64_t const& desc_b,
+      uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3,
+      uint32_t & d4, uint32_t & d5,
+      uint32_t const& e,
+      GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One)
+  {
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED)
+    cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b);
+    asm volatile(
+    "{\n"
+      ".reg .pred p;\n"
+      "setp.ne.b32 p, %10, 0;\n"
+      "wgmma.mma_async.sp.sync.aligned.m64n24k32.f16.f16.f16 "
+      "{%0, %1, %2, %3, %4, %5},"
+      " %6,"
+      " %7,"
+      " %8, %9,"
+      " p, %11, %12, %13, %14;\n"
+    "}\n"
+      : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3),
+        "+r"(d4), "+r"(d5)
+      : "l"(desc_a),
+        "l"(desc_b),
+        "r"(e), "n"(int32_t(spsel)),
+        "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB)));
+#else
+    CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED");
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+// SPARSE GMMA 64x24x32 F16+=F16*F16
+template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " %11, %12," + " p, %14, %15, %16;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " %12, %13," + " p, %15, %16, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + 
GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " %15, %16," + " p, %18, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA 
= GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " %16, %17," + " p, %19, %20, %21, %22;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " %19, %20," + " p, %22, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " %20, %21," + " p, %23, %24, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), 
"+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " %23, %24," + " p, %26, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const 
scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero +> +struct GMMA_64x88x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " %24, %25," + " p, %27, %28, %29, %30;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " %27, %28," + " p, %30, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + 
"+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " %28, %29," + " p, %31, %32, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t 
& d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " %31, %32," + " p, %34, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), 
"+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t 
& d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " %32, %33," + " p, %35, %36, %37, %38;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " %35, %36," + " p, %38, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), 
"+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " %36, %37," + " p, %39, %40, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x32_F16F16F16_RS +{ + using DRegisters = void; + using 
ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " %39, %40," + " p, %42, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), 
"+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " %40, %41," + " p, %43, %44, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = 
uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " %43, %44," + " p, %46, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t 
& d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " 
%44," + " %45, %46," + " p, %48, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " %44, %45," + " p, %47, %48, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x168x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " %47, %48," + " p, %50, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = 
uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, 
uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, 
%6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " %48, %49," + " p, %51, %52, %53, %54;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " %51, %52," + " p, %54, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), 
"+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " %52, %53," + " p, %55, %56, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), 
"n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " %55, %56," + " p, %58, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x32 
F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + static_assert(tnspA == 
GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & 
d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " %56, %57," + " p, %59, %60, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t 
& d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " %59, %60," + " p, %62, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & 
d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t 
const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n232k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " %60, %61," + " p, %63, %64, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, 
%5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " %63, %64," + " p, %66, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, 
%38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " 
%40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x32_F16F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, 
%61}," + " %62," + " %63," + " %64, %65," + " p, %67, %68, %69, %70;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x32_F16F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x32 F16+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x32_F16F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k32.f16.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, 
%52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " %67, %68," + " p, %70, %71, %72;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x32_F16F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); 
+ + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> 
+struct GMMA_64x40x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : 
"l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float 
& d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x56x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, 
float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), 
"+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x32 F32+=F16*F16 
+template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, 
float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k32.f32.f16.f16 " + "{%0, %1, 
%2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), 
"+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + 
"+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x112x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = 
uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %72, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " %70, %71," + " p, %73, %74, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
float[68]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %75, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " %73, %74," + " p, %76, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = 
uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using 
ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB 
= GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %80, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " %78, %79," + " p, %81, %82, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 
64x152x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %83, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " %81, %82," + " p, %84, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86, %87, %88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + 
"+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), 
"+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %88, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " %86, %87," + " p, %89, %90, %91, %92;\n" + "}\n" + : "+f"(d00), 
"+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %91, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, 
%6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " %89, %90," + " p, %92, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float 
& d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, 
float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %96, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " %94, %95," + " p, %97, %98, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x184x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %99, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " %97, %98," + " p, %100, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), 
"+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %104, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, 
%42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " %102, %103," + " p, %105, %106, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float 
& d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %107, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " %105, %106," + " p, %108, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + 
GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + 
"+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f32.f16.f16 " + 
"{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, 
float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %112, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " %110, %111," + " p, %113, %114, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %115, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " %113, %114," + " p, %116, %117, %118;\n" + "}\n" + : 
"+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, 
float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F32F16F16_RS 
+{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), 
+ "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & 
d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %120, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " %118, %119," + " p, %121, %122, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major 
layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %123, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " %121, %122," + " p, %124, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), 
"+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + 
uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A 
must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), 
"+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x32_F32F16F16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, 
float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %128, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " %126, %127," + " p, %129, %130, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x32_F32F16F16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x32 F32+=F16*F16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero +> +struct GMMA_64x248x32_F32F16F16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %131, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k32.f32.f16.f16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " %129, %130," + " p, %132, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), 
"+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x32_F32F16F16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18, %19, %20;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21, %22;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26, %27, %28;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + 
"l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30, %31, %32;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34, %35, %36;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, 
%22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37, %38;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42, %43, %44;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + 
GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut 
const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46, %47, %48;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), 
"+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50, %51, %52;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn 
scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53, %54;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, 
float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58, %59, %60;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " 
%40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62, %63, %64;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + 
"+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66, %67, %68;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), 
"+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69, %70;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), 
"r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %72, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " %70, %71," + " p, %73, %74, %75, %76;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %75, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " %73, %74," + " p, %76, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), 
"r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78, %79, %80;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), 
"+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), 
"+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %80, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " %78, %79," + " p, %81, %82, %83, %84;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), 
"+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %83, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " %81, %82," + " p, %84, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), 
"+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86, %87, 
%88;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " 
%16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { 
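+    // The inline PTX below follows the operand-numbering convention shared by
+    // the surrounding sparse GMMA wrappers: with N accumulator registers
+    // (N == 84 for this 64x168x32 atom), the accumulators bind to %0..%(N-1),
+    // followed by desc_a (%N), desc_b (%N+1), the sparse metadata word e
+    // (%N+2), the immediate sparse selector spsel (%N+3), the runtime scale_D
+    // flag (%N+4, consumed via predicate p), and the immediate scaleA, scaleB,
+    // tnspA and tnspB modifiers (%N+5..%N+8). The RS variants substitute four
+    // 32-bit A fragments for desc_a and drop tnspA, since register-sourced A
+    // must be K-major (see the static_assert in each RS struct).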
+#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %88, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " %86, %87," + " p, %89, %90, %91, %92;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & 
d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %91, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " %89, %90," + " p, %92, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, 
float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94, %95, %96;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x176x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : 
"r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %96, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " %94, %95," + " p, %97, %98, %99, %100;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + 
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %99, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " 
%40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " %97, %98," + " p, %100, %101, %102;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, 
float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %104, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " %102, %103," + " p, %105, %106, %107, %108;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE 
static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %107, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " %105, %106," + " p, %108, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), 
"+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, 
%84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110, %111, %112;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & 
d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero +> +struct GMMA_64x216x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %112, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " %110, %111," + " p, %113, %114, %115, %116;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), 
"+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %115, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " %113, %114," + " p, %116, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float 
& d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118, %119, %120;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), 
"+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, 
%46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & 
d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %120, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " %118, %119," + " p, %121, %122, %123, %124;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + 
"l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %123, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, 
%78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " %121, %122," + " p, %124, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & 
d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126, %127, %128;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), 
"+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, 
%62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x32_F32BF16BF16_SS +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, 
float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %128, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " %126, %127," + " p, %129, %130, %131, %132;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), 
"+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspA)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x32_F32BF16BF16_SS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x32 F32+=BF16*BF16 +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x32_F32BF16BF16_RS +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + static_assert(tnspA == GMMA::Major::K, + "Register source operand A must have K major layout."); + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %131, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n248k32.f32.bf16.bf16 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " %129, %130," + " p, %132, %133, %134;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB)), "n"(int32_t(tnspB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x32_F32BF16BF16_RS without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& 
e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & 
d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
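Editor's note (illustrative only, not part of the patch): every struct in this hunk follows the same calling pattern — accumulator references first, then the A operand (a shared-memory descriptor for the `_SS` variants, or four 32-bit registers for the `_RS` variants), the B shared-memory descriptor, the 32-bit sparse-metadata word `e`, and finally `scale_D`. The minimal sketch below shows a direct call to one of the smaller atoms so the operand shape is easy to see. The `cute::` namespace qualification and the include path are assumptions on my part; the atom name, its template defaults, and the `fma` signature are taken from the struct definitions above. Real kernels would normally reach these atoms through CuTe's higher-level MMA machinery, build `desc_a`/`desc_b`/`e` from tiled shared-memory layouts, and must be compiled for sm_90a so that CUTE_ARCH_MMA_SM90A_ENABLED is defined.

// Sketch of a direct invocation of a sparse GMMA atom (assumed include path).
#include <cute/arch/mma_sm90_gmma_sparse.hpp>

__device__ void example_sparse_gmma_64x24x16_tf32(
    uint64_t desc_a,      // SMEM descriptor for operand A (built elsewhere)
    uint64_t desc_b,      // SMEM descriptor for operand B (built elsewhere)
    uint32_t e,           // sparsity metadata word (ERegisters = uint32_t[1])
    float   (&acc)[12])   // accumulators; count matches CRegisters = float[12]
{
  // Default template arguments: scaleA = scaleB = ScaleIn::One, spsel = SparseSel::Zero.
  using Atom = cute::SM90::GMMA::SPARSE::GMMA_64x24x16_F32TF32TF32_SS_TN<>;

  // Operand order mirrors the fma signature above:
  // desc_a, desc_b, d00..d11, e, scale_D.
  Atom::fma(desc_a, desc_b,
            acc[0], acc[1], acc[2],  acc[3],
            acc[4], acc[5], acc[6],  acc[7],
            acc[8], acc[9], acc[10], acc[11],
            e, cute::GMMA::ScaleOut::One);
}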
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, 
%14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x16_F32TF32TF32_RS_TN +{ + using DRegisters = 
void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n80k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & 
d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n104k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + 
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero +> +struct GMMA_64x120x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & 
d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, 
float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %72, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " %70, %71," + " p, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, 
float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %75, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " %73, %74," + " p, %76, %77;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, 
float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float 
& d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & 
d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %80, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " %78, %79," + " p, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & 
d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %83, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " %81, %82," + " p, %84, %85;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & 
d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & 
d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & 
d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %88, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " %86, %87," + " p, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x16_F32TF32TF32_RS_TN +{ + using DRegisters = 
void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %91, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " %89, %90," + " p, %92, %93;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x16_F32TF32TF32_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + 
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), 
"+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %96, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, 
%69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " %94, %95," + " p, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %99, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " %97, %98," + " p, %100, %101;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & 
d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %104, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " %102, %103," + " p, %105, %106;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x200x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %107, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " %105, %106," + " p, %108, %109;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), 
"+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, 
%55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & 
d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x16_F32TF32TF32_SS_TN +{ + using 
DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %112, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " %110, %111," + " p, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), 
"+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %115, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, 
%15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " %113, %114," + " p, %116, %117;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + 
float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121;\n" + "}\n" + : "+f"(d000), "+f"(d001), 
"+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & 
d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %120, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " %118, %119," + " p, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> 
+struct GMMA_64x232x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %123, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " %121, %122," + " p, %124, %125;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), 
"+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & 
d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + 
CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + 
"+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x16_F32TF32TF32_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + 
float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %128, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " %126, %127," + " p, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x16_F32TF32TF32_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x16 TN F32+=TF32*TF32 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x16_F32TF32TF32_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = 
uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %131, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k16.f32.tf32.tf32 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " %129, %130," + " p, %132, %133;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), 
"+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x16_F32TF32TF32_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & 
d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + 
uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + 
"l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + 
uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t 
& d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm 
volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + 
asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D 
= GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t 
& d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, 
uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + 
uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> 
+struct GMMA_64x208x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), 
"+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + 
uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), 
"+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, 
%51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & 
d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), 
"+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, 
%36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), 
"+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + 
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, 
uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), 
"+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8S8_RS_TN +{ + using DRegisters = void; 
+ using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
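+
+// Note: a minimal, hypothetical usage sketch of one of the RS atoms above; the
+// names frag_a, desc_b, meta, and acc are illustrative only. In practice these
+// atoms are selected through their cute::MMA_Traits specializations and issued
+// via cute::gemm(), with warpgroup_arrive(), warpgroup_commit_batch(), and
+// warpgroup_wait<0>() fencing the asynchronous wgmma as usual.
+//
+//   uint32_t frag_a[4];     // A operand in registers (RS atom), per-thread fragment
+//   uint64_t desc_b;        // shared-memory matrix descriptor for B
+//   uint32_t meta;          // packed 2:4 sparsity metadata (the E operand)
+//   uint32_t acc[72] = {};  // S32 accumulator fragment for the 64x144 tile
+//
+//   SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8S8_RS_TN<>::fma(
+//       frag_a[0], frag_a[1], frag_a[2], frag_a[3],
+//       desc_b,
+//       acc[0], acc[1], /* ..., */ acc[71],   // all 72 accumulators are passed explicitly
+//       meta,
+//       GMMA::ScaleOut::One);  // ScaleOut::One accumulates (scale-d != 0);
+//                              // ScaleOut::Zero overwrites the accumulators.
+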
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + 
"+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), 
"+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, 
%83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & 
d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
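The atoms above differ only in the tile width N and in whether the A operand is sourced from registers (RS) or from shared memory via a descriptor (SS); the register accounting they encode follows directly from the declarations visible in the structs. The short standalone C++ sketch below restates that accounting and checks it against a few of the CRegisters sizes shown above; the helper name is hypothetical and illustrative, not a CUTLASS API, and the sketch is not part of the diff itself.

// For an m64nNk64 sparse GMMA with s32 accumulators, the 64 x N output tile
// is spread over the 128 threads of the warpgroup, so each thread carries
// N/2 packed 32-bit accumulators -- exactly the CRegisters sizes declared
// above. The RS variants additionally keep the compressed 2:4-sparse A
// fragment in four 32-bit registers per thread (ARegisters), the SS variants
// take a single 64-bit shared-memory descriptor instead, and ERegisters is
// the one 32-bit sparsity-metadata word passed as `e`.

// Hypothetical helper, for illustration only.
constexpr int sparse_gmma_s32_acc_regs_per_thread(int n) { return n / 2; }

static_assert(sparse_gmma_s32_acc_regs_per_thread(176) ==  88,
              "matches GMMA_64x176x64_S32S8S8_RS_TN::CRegisters");
static_assert(sparse_gmma_s32_acc_regs_per_thread(208) == 104,
              "matches GMMA_64x208x64_S32S8S8_RS_TN::CRegisters");
static_assert(sparse_gmma_s32_acc_regs_per_thread(240) == 120,
              "matches GMMA_64x240x64_S32S8S8_RS_TN::CRegisters");
static_assert(sparse_gmma_s32_acc_regs_per_thread(24)  ==  12,
              "matches GMMA_64x24x64_S32S8U8_SS_TN::CRegisters");

The same accounting explains why the inline asm of each SATURATE variant is identical to its non-saturating twin except for the `.satfinite` qualifier: the operand layout depends only on N and on the RS/SS choice, not on the saturation behavior.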
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), 
"+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, 
uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + 
CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), 
"+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + 
uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=S8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8S8_RS_TN_SATURATE +{ + 
using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), 
"+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8U8_SS_TN_SATURATE +{ + using 
DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = 
uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), 
"+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & 
d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, 
uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, 
uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & 
d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & 
d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + 
uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, 
uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& 
desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), 
"+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, 
uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, 
uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), 
"+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & 
d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), 
"+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, 
desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg 
.pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, 
uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, 
%19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), 
"+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32S8U8_RS_TN_SATURATE 
without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), 
"r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), 
"+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + 
"+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, 
%38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, 
uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, 
uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : 
"+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + 
uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = 
void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), 
"+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, 
+ uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 
64x240x64 TN S32+=S8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32S8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.s8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, 
%116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32S8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN 
S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> 
+struct GMMA_64x48x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + 
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, 
uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, 
uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & 
d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & 
d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + 
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, 
uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = 
uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + 
"r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), 
"+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, 
uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + 
uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), 
"+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + 
"{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8S8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + 
uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), 
"+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8S8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
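+    // RS variant: operand A is supplied directly in registers (a00..a03) while operand B is read
+    // through the shared-memory matrix descriptor desc_b; e carries the sparsity metadata for A,
+    // and spsel selects which threads contribute that metadata.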
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + 
uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.s8.satfinite " + 
"{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), 
"+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), 
"r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + 
"+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + 
"+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), 
"+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, 
%19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & 
d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, 
uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), 
"r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, 
%102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & 
d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + 
"+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8S8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, 
uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.s8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
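// Illustrative note (not part of the upstream header): the _RS_ atoms above carry the
// A operand as four 32-bit register fragments (ARegisters = uint32_t[4], a000..a003),
// while the _SS_ atoms further below take a 64-bit shared-memory matrix descriptor
// (desc_a) for A, just as all of these atoms do for B. A minimal compile-time check of
// that convention, with the enclosing cute::SM90::GMMA::SPARSE namespace inferred from
// the CUTE_INVALID_CONTROL_PATH strings:
using RsAtomSketch = cute::SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_RS_TN<>;
static_assert(sizeof(RsAtomSketch::ARegisters) == 4 * sizeof(uint32_t),
              "RS atoms source A from four 32-bit register fragments");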
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=U8*S8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8S8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.s8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, 
%97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8S8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
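// Illustrative sketch (not part of the upstream header): direct use of the smallest SS
// atom above, GMMA_64x24x64_S32U8U8_SS_TN. The qualified name is inferred from the
// CUTE_INVALID_CONTROL_PATH string; in real code these atoms are normally reached via
// cute::MMA_Atom / MMA_Traits, and the wgmma must be framed by the usual warpgroup
// fence / commit / wait sequence, omitted here for brevity.
__device__ void sparse_gmma_64x24x64_sketch(uint64_t desc_a,      // SMEM descriptor of the compressed 2:4-sparse u8 A tile
                                            uint64_t desc_b,      // SMEM descriptor of the u8 B tile (k64 x n24)
                                            uint32_t e,           // packed sparsity metadata (ERegisters = uint32_t[1])
                                            uint32_t (&acc)[12])  // per-thread s32 accumulators (CRegisters = uint32_t[12])
{
  using Atom = cute::SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8U8_SS_TN<>;
  // One warpgroup-wide sparse MMA, acc (64x24, s32) += A * B; scale_D is left at its
  // default of GMMA::ScaleOut::One so the existing accumulator contents are retained.
  Atom::fma(desc_a, desc_b,
            acc[0], acc[1], acc[2], acc[3], acc[4],  acc[5],
            acc[6], acc[7], acc[8], acc[9], acc[10], acc[11],
            e);
}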
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
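// Illustrative note (not part of the upstream header): each _SATURATE twin differs from
// its base atom only in its name and in the ".satfinite" qualifier on the PTX
// instruction, which clamps the s32 result to the representable range when the
// accumulation overflows; the operand footprint is identical.
using SparseU8U8Sketch    = cute::SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_SS_TN<>;
using SparseU8U8SatSketch = cute::SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_SS_TN_SATURATE<>;
static_assert(sizeof(SparseU8U8Sketch::CRegisters) == sizeof(SparseU8U8SatSketch::CRegisters),
              "saturation changes the arithmetic, not the register layout");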
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, 
%23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, 
uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, 
uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & 
d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, 
+ uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, 
uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & 
d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 
64x176x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), 
+ "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, 
%82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & 
d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& 
desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + 
"+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, 
uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8U8_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, 
uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + 
"+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8U8_SS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, 
uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_SS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, 
uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & 
d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; 
+ using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), 
"+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + 
"+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), 
"+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), 
"+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), 
"+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, 
%69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + 
"{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & 
d61, uint32_t & d62, uint32_t & d63, + uint32_t & d64, uint32_t & d65, uint32_t & d66, uint32_t & d67, + uint32_t & d68, uint32_t & d69, uint32_t & d70, uint32_t & d71, + uint32_t & d72, uint32_t & d73, uint32_t & d74, uint32_t & d75, + uint32_t & d76, uint32_t & d77, uint32_t & d78, uint32_t & d79, + uint32_t & d80, uint32_t & d81, uint32_t & d82, uint32_t & d83, + uint32_t & d84, uint32_t & d85, uint32_t & d86, uint32_t & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61), "+r"(d62), "+r"(d63), + "+r"(d64), "+r"(d65), "+r"(d66), "+r"(d67), + "+r"(d68), "+r"(d69), "+r"(d70), "+r"(d71), + "+r"(d72), "+r"(d73), "+r"(d74), "+r"(d75), + "+r"(d76), "+r"(d77), "+r"(d78), "+r"(d79), + "+r"(d80), "+r"(d81), "+r"(d82), "+r"(d83), + "+r"(d84), "+r"(d85), "+r"(d86), "+r"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & 
d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + 
"+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, 
" + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, 
uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, 
%111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8U8_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, 
uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.u8 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + 
"r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN S32+=U8*U8 +template < + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_S32U8U8_RS_TN_SATURATE +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + uint32_t & d000, uint32_t & d001, uint32_t & d002, uint32_t & d003, + uint32_t & d004, uint32_t & d005, uint32_t & d006, uint32_t & d007, + uint32_t & d008, uint32_t & d009, uint32_t & d010, uint32_t & d011, + uint32_t & d012, uint32_t & d013, uint32_t & d014, uint32_t & d015, + uint32_t & d016, uint32_t & d017, uint32_t & d018, uint32_t & d019, + uint32_t & d020, uint32_t & d021, uint32_t & d022, uint32_t & d023, + uint32_t & d024, uint32_t & d025, uint32_t & d026, uint32_t & d027, + uint32_t & d028, uint32_t & d029, uint32_t & d030, uint32_t & d031, + uint32_t & d032, uint32_t & d033, uint32_t & d034, uint32_t & d035, + uint32_t & d036, uint32_t & d037, uint32_t & d038, uint32_t & d039, + uint32_t & d040, uint32_t & d041, uint32_t & d042, uint32_t & d043, + uint32_t & d044, uint32_t & d045, uint32_t & d046, uint32_t & d047, + uint32_t & d048, uint32_t & d049, uint32_t & d050, uint32_t & d051, + uint32_t & d052, uint32_t & d053, uint32_t & d054, uint32_t & d055, + uint32_t & d056, uint32_t & d057, uint32_t & d058, uint32_t & d059, + uint32_t & d060, uint32_t & d061, uint32_t & d062, uint32_t & d063, + uint32_t & d064, uint32_t & d065, uint32_t & d066, uint32_t & d067, + uint32_t & d068, uint32_t & d069, uint32_t & d070, uint32_t & d071, + uint32_t & d072, uint32_t & d073, uint32_t & d074, uint32_t & d075, + uint32_t & d076, uint32_t & d077, uint32_t & d078, uint32_t & d079, + uint32_t & d080, uint32_t & d081, uint32_t & d082, uint32_t & d083, + uint32_t & d084, uint32_t & d085, uint32_t & d086, uint32_t & d087, + uint32_t & d088, uint32_t & d089, uint32_t & d090, uint32_t & d091, + uint32_t & d092, uint32_t & d093, uint32_t & d094, uint32_t & d095, + uint32_t & d096, uint32_t & d097, uint32_t & d098, uint32_t & d099, + uint32_t & d100, uint32_t & d101, uint32_t & d102, uint32_t & d103, + uint32_t & d104, uint32_t & d105, uint32_t & d106, uint32_t & d107, + uint32_t & d108, uint32_t & d109, uint32_t & d110, uint32_t & d111, + uint32_t & d112, uint32_t & d113, uint32_t & d114, uint32_t & d115, + uint32_t & d116, uint32_t & d117, uint32_t & d118, uint32_t & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.s32.u8.u8.satfinite " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, 
%60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p;\n" + "}\n" + : "+r"(d000), "+r"(d001), "+r"(d002), "+r"(d003), + "+r"(d004), "+r"(d005), "+r"(d006), "+r"(d007), + "+r"(d008), "+r"(d009), "+r"(d010), "+r"(d011), + "+r"(d012), "+r"(d013), "+r"(d014), "+r"(d015), + "+r"(d016), "+r"(d017), "+r"(d018), "+r"(d019), + "+r"(d020), "+r"(d021), "+r"(d022), "+r"(d023), + "+r"(d024), "+r"(d025), "+r"(d026), "+r"(d027), + "+r"(d028), "+r"(d029), "+r"(d030), "+r"(d031), + "+r"(d032), "+r"(d033), "+r"(d034), "+r"(d035), + "+r"(d036), "+r"(d037), "+r"(d038), "+r"(d039), + "+r"(d040), "+r"(d041), "+r"(d042), "+r"(d043), + "+r"(d044), "+r"(d045), "+r"(d046), "+r"(d047), + "+r"(d048), "+r"(d049), "+r"(d050), "+r"(d051), + "+r"(d052), "+r"(d053), "+r"(d054), "+r"(d055), + "+r"(d056), "+r"(d057), "+r"(d058), "+r"(d059), + "+r"(d060), "+r"(d061), "+r"(d062), "+r"(d063), + "+r"(d064), "+r"(d065), "+r"(d066), "+r"(d067), + "+r"(d068), "+r"(d069), "+r"(d070), "+r"(d071), + "+r"(d072), "+r"(d073), "+r"(d074), "+r"(d075), + "+r"(d076), "+r"(d077), "+r"(d078), "+r"(d079), + "+r"(d080), "+r"(d081), "+r"(d082), "+r"(d083), + "+r"(d084), "+r"(d085), "+r"(d086), "+r"(d087), + "+r"(d088), "+r"(d089), "+r"(d090), "+r"(d091), + "+r"(d092), "+r"(d093), "+r"(d094), "+r"(d095), + "+r"(d096), "+r"(d097), "+r"(d098), "+r"(d099), + "+r"(d100), "+r"(d101), "+r"(d102), "+r"(d103), + "+r"(d104), "+r"(d105), "+r"(d106), "+r"(d107), + "+r"(d108), "+r"(d109), "+r"(d110), "+r"(d111), + "+r"(d112), "+r"(d113), "+r"(d114), "+r"(d115), + "+r"(d116), "+r"(d117), "+r"(d118), "+r"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_S32U8U8_RS_TN_SATURATE without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " %8, %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " %11, %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, 
+ GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " %12, %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; 
+ using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " %15, %16," + " p, %18, %19;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E4M3E4M3_RS_TN +{ + using 
DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; 
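+
+// Implementation note (common to the SM90 sparse GMMA wrappers in this file):
+//   * _SS_ variants source both A and B from shared memory via 64-bit matrix
+//     descriptors (desc_a, desc_b); _RS_ variants source A from registers
+//     (a00..a03) and B from a shared memory descriptor.
+//   * The d registers are the per-thread accumulator fragment; F16 accumulators
+//     are packed two halves per 32-bit register, F32 accumulators are one float
+//     per register.
+//   * e carries the structured-sparsity metadata for A and spsel is the
+//     sparsity-selector immediate.
+//   * scale_D is lowered to the wgmma scale-d predicate: non-zero accumulates
+//     into D (D = A*B + D), zero overwrites it (D = A*B); scaleA/scaleB map to
+//     the +/-1 immediate input scales.
+//   See the PTX ISA description of wgmma.mma_async.sp for the authoritative
+//   operand semantics.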
+ +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " %16, %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), 
"+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " %19, %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, 
%20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, 
+ uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " %20, %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " %23, %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + 
CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), 
"+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = 
uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " %24, %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " %27, %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & 
d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " %28, %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), 
"+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " %31, %32," + " p, %34, %35;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float 
& d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, 
float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " %32, %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) 
+ : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " %35, %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + 
float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & 
d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, 
%27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " %36, %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " %39, %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %72, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " %70, %71," + " p, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using 
ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %75, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " %73, %74," + " p, %76, %77;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), 
"+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + 
"+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " %40, %41," + " p, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " %43, %44," + " p, %46, %47;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, 
float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %80, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " %78, %79," + " p, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, 
float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %83, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " %81, %82," + " p, %84, %85;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& 
desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " 
%48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " 
+ " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, 
%36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " %44, %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " %47, %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %88, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " %86, %87," + " p, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), 
"+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %91, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " %89, %90," + " p, %92, %93;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + 
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t 
const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E4M3*E4M3 
+template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), 
"+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " %48, %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x184x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " %51, %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, 
float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %96, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " %94, %95," + " p, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using 
ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %99, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " %97, %98," + " p, %100, %101;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + 
"l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " %52, %53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> 
+struct GMMA_64x200x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " %55, %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, 
float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %104, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " %102, %103," + " p, %105, %106;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + 
"r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %107, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " %105, %106," + " p, %108, %109;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), 
"+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), 
"+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + 
"+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), 
"+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + 
float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, 
uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " %56, %57," + " p, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & 
d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " %59, %60," + " p, %62, %63;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, 
float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %112, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " %110, %111," + " p, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %115, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " %113, %114," + " p, %116, %117;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), 
"+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), 
"+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), 
"+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, 
%98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & 
d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " %60, %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = 
uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " %63, %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + 
float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %120, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " %118, %119," + " p, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), 
"+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %123, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, 
%12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " %121, %122," + " p, %124, %125;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t 
& d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + 
uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float 
& d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, 
%108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, 
uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " %64, %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " %67, %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E4M3E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float 
& d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %128, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " %126, %127," + " p, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
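+// Conventions shared by the generated sparse GMMA wrappers above and below
+// (summarized here for readability; every struct follows the same pattern):
+//   * _SS_ atoms take both A and B as 64-bit shared-memory matrix descriptors
+//     (ARegisters = uint64_t[1]); _RS_ atoms take A as four 32-bit register
+//     fragments (ARegisters = uint32_t[4]) and B as a descriptor.
+//   * `e` is the 32-bit sparsity-metadata word (ERegisters), and the SparseSel
+//     template parameter `spsel` is baked into the PTX as an immediate.
+//   * `scale_D` is a runtime argument: the `setp`/predicate `p` idiom maps
+//     GMMA::ScaleOut::One to "accumulate into D" and ScaleOut::Zero to
+//     "overwrite D", while scaleA/scaleB are compile-time immediates.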
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E4M3*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E4M3E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %131, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e4m3.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " 
%129, %130," + " p, %132, %133;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " %8, %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E4M3*E5M2 +template < + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " %11, %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + 
CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " %12, %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, 
uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " %15, %16," + " p, %18, %19;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + 
float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + 
fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x48x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " %16, %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
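+// Comment-only usage sketch (illustrative; placeholder names, not generated code).
+// These atoms are normally reached through CuTe's MMA_Atom/TiledMMA layer, which
+// also issues the surrounding wgmma fence/commit/wait synchronization. A direct
+// call to the SS variant above passes shared-memory matrix descriptors for A and B,
+// the packed accumulator registers, and the sparsity-metadata word `e`; `desc_a`,
+// `desc_b`, `acc`, and `e` below are hypothetical placeholders.
+//
+//   using Atom = SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E5M2_SS_TN<>;
+//   uint32_t acc[14] = {};  // 14 packed f16x2 accumulator words (CRegisters)
+//   Atom::fma(desc_a, desc_b,
+//             acc[0], acc[1], acc[2],  acc[3],  acc[4],  acc[5],  acc[6],
+//             acc[7], acc[8], acc[9],  acc[10], acc[11], acc[12], acc[13],
+//             e, GMMA::ScaleOut::One);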
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " %19, %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + 
"+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, 
+ uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " %20, %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " %23, %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t 
const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+f"(d00), 
"+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, 
uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t 
& d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " %24, %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " %27, %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn 
scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float 
& d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " %28, %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " %31, %32," + " p, %34, %35;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, 
float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, 
%20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN 
F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t 
const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, 
%34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " %32, %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
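+// Exposition note on the register aliases above (descriptive comment only): the
+// nested DRegisters/ARegisters/BRegisters/CRegisters/ERegisters types are what
+// CuTe's MMA_Atom machinery uses to size operand fragments. For this 64x120x64
+// atom, ARegisters = uint64_t[1] and BRegisters = uint64_t[1] indicate both
+// operands arrive as shared-memory descriptors (the SS variant; the RS variant
+// below instead takes A as four 32-bit registers), CRegisters = uint32_t[30] is
+// the packed f16 accumulator fragment, and ERegisters = uint32_t[1] carries the
+// sparsity metadata passed to fma() as `e`.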
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " %35, %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, 
float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
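    // RS variant: the A fragment is supplied from registers (a00..a03) while B is read through a
    // shared-memory descriptor; e carries the 2:4 sparsity metadata consumed by wgmma.mma_async.sp.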
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " %36, %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), 
"+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " %39, %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + 
CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %72, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " %70, %71," + " p, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, 
float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %75, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " %73, %74," + " p, %76, %77;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, 
uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, 
%45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), 
"+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + 
"+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " %40, %41," + " p, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + 
GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " %43, %44," + " p, %46, %47;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & 
d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %80, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " %78, %79," + " p, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & 
d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %83, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " %81, %82," + " p, %84, %85;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, 
+ uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, 
%39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), 
"+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, 
%79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " %44, %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), 
"+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " %47, %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %88, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " %86, %87," + " p, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E5M2_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %91, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " %89, %90," + " p, %92, %93;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + 
"+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x176x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, 
float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t 
const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " %48, %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & 
d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " %51, %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & 
d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %96, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " %94, %95," + " p, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, 
float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %99, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " %97, %98," + " p, %100, %101;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN 
F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " %52, %53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, 
uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " %55, %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, 
float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %104, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " %102, %103," + " p, %105, %106;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %107, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " %105, %106," + " p, %108, %109;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), 
+ "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), 
"+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + 
"+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + 
"setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & 
d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " %56, %57," + " p, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " %59, %60," + " p, %62, %63;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, 
float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %112, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " %110, %111," + " p, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & 
d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %115, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " %113, %114," + " p, %116, %117;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), 
"+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), 
"+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E4M3E5M2_RS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), 
"+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, 
float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t 
const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " %60, %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, 
uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " %63, %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, 
float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %120, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " %118, %119," + " p, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + 
"+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %123, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, 
" + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " %121, %122," + " p, %124, %125;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + 
uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & 
d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E4M3*E5M2 +template < + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + 
"+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " %64, %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, 
%14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " %67, %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E4M3E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & 
d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %128, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " %126, %127," + " p, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E4M3*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x248x64_F32E4M3E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %131, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e4m3.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " %129, %130," + " p, %132, %133;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), 
"+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F32E4M3E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " %8, %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + 
using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " %11, %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& 
e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " %12, %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + 
".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " %15, %16," + " p, %18, %19;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = 
GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t 
const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + 
uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " %16, %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x56x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " %19, %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting 
to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " %20, %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " %23, %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & 
d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + 
"+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE 
GMMA 64x80x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if 
defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + " %22," + " %23," + " %24, %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " %27, %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = 
uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " %28, %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
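+// Operand convention shared by the sparse GMMA wrapper structs above and below (descriptive
+// summary only; the authoritative semantics are those of the PTX "wgmma.mma_async.sp"
+// instruction):
+//   CRegisters    : accumulator fragment, read-modify-write ("+r"/"+f" constraints).
+//   ARegisters    : one uint64_t shared-memory matrix descriptor (desc_a) for _SS_ atoms,
+//                   or four uint32_t registers holding the A fragment for _RS_ atoms.
+//   BRegisters    : one uint64_t shared-memory matrix descriptor (desc_b).
+//   ERegisters    : the structured-sparsity metadata word `e`; the immediate `spsel`
+//                   (GMMA::SparseSel) selects which threads' metadata is consumed.
+//   scale_D       : lowered to predicate `p`; when zero the accumulator is overwritten
+//                   (D = A*B) rather than accumulated into (D += A*B).
+//   scaleA/scaleB : immediate +/-1 inputs (GMMA::ScaleIn) applied to the A and B operands.
+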
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " %31, %32," + " p, %34, %35;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + 
uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + 
"{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct 
GMMA_64x112x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, 
desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," 
+ " p, %64, %65;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " %32, %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + 
GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " %35, %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t 
const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, 
%12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " %36, %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(e), 
"n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " %39, %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, 
float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %72, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " %70, %71," + " p, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, 
float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %75, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " %73, %74," + " p, %76, %77;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, 
uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), 
"+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + 
"+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), 
+ "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " %40, %41," + " p, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE 
static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " %43, %44," + " p, %46, %47;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float 
& d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %80, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " %78, %79," + " p, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float 
& d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %83, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " %81, %82," + " p, %84, %85;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, 
uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), 
"+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), 
"+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), 
"+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " %44, %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), 
"n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " %47, %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + 
fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %88, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " %86, %87," + " p, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %91, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " %89, %90," + " p, %92, %93;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t 
const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & 
d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & 
d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E5M2E4M3_SS_TN +{ 
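+  // Register layout of this sparse GMMA atom (the same convention is used by the
+  // other wgmma.sp atoms in this header): D has no registers of its own because the
+  // accumulator lives in the C registers; A (SS variant) is addressed through one
+  // 64-bit shared-memory matrix descriptor; E holds the 2:4 structured-sparsity
+  // metadata; B is likewise a 64-bit shared-memory descriptor. The 46 uint32_t C
+  // registers pack the 92 F16 accumulator elements owned by each of the 128 threads
+  // of the warpgroup (two F16 values per 32-bit register) for the 64x184 tile.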
+ using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " %48, %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t 
& d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " %51, %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, 
float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %96, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " %94, %95," + " p, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & 
d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %99, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " %97, %98," + " p, %100, %101;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using 
CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " %52, %53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t 
& d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " %55, %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, 
float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %104, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " %102, %103," + " p, %105, %106;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + 
fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %107, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " %105, %106," + " p, %108, %109;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + 
"+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E4M3_SS_TN 
without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = 
void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), 
"+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, 
%50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + "{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + 
GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " %56, %57," + " p, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, 
%17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " %59, %60," + " p, %62, %63;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %112, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " %110, %111," + " p, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, 
+ float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %115, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " %113, %114," + " p, %116, %117;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + 
"+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = 
GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), 
"+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & 
d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " %60, %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, 
uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " %63, %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & 
d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %120, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " %118, %119," + " p, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + 
CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %123, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " %121, %122," + 
" p, %124, %125;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, 
float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using 
ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), 
"+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, 
%34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " %64, %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " 
%56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " %67, %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E5M2E4M3_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t 
const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %128, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " %126, %127," + " p, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E4M3_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E5M2*E4M3 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E5M2E4M3_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& 
a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %131, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e5m2.e4m3 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " %129, %130," + " p, %132, %133;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), 
"+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E4M3_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %10, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + " %6," + " %7," + " %8, %9," + " p, %11, %12;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[6]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a0, uint32_t const& a1, uint32_t const& a2, uint32_t const& a3, + uint64_t const& desc_b, + uint32_t & d0, uint32_t & d1, uint32_t & d2, uint32_t & d3, + uint32_t & d4, uint32_t & d5, + 
uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %13, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5}," + "{%6, %7, %8, %9}," + " %10," + " %11, %12," + " p, %14, %15;\n" + "}\n" + : "+r"(d0), "+r"(d1), "+r"(d2), "+r"(d3), + "+r"(d4), "+r"(d5) + : "r"(a0), "r"(a1), "r"(a2), "r"(a3), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x24x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x24x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n24k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x24x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %14, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + " %10," + " %11," + " %12, %13," + " p, %15, %16;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[10]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %17, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9}," + "{%10, %11, %12, %13}," + " %14," + " %15, %16," + " p, %18, %19;\n" + "}\n" + 
: "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x40x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x40x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n40k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + "{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x40x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %16, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + " %12," + " %13," + " %14, %15," + " p, %17, %18;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[12]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg 
.pred p;\n" + "setp.ne.b32 p, %19, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11}," + "{%12, %13, %14, %15}," + " %16," + " %17, %18," + " p, %20, %21;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %28, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + " %24," + " %25," + " %26, %27," + " p, %29, %30;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x48x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x48x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[24]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & 
d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %31, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n48k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23}," + "{%24, %25, %26, %27}," + " %28," + " %29, %30," + " p, %32, %33;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x48x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %18, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + " %14," + " %15," + " %16, %17," + " p, %19, %20;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[14]; + + CUTE_HOST_DEVICE static void 
+ fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %21, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13}," + "{%14, %15, %16, %17}," + " %18," + " %19, %20," + " p, %22, %23;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x56x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE 
GMMA 64x56x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x56x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n56k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x56x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %22, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + " %18," + " %19," + " %20, %21," + " p, %23, %24;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + 
"+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[18]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %25, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17}," + "{%18, %19, %20, %21}," + " %22," + " %23, %24," + " p, %26, %27;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut 
const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x72x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x72x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n72k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), 
"n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x72x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %24, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + " %20," + " %21," + " %22, %23," + " p, %25, %26;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[20]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %27, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19}," + 
"{%20, %21, %22, %23}," + " %24," + " %25, %26," + " p, %28, %29;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x80x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x80x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters 
= uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n80k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x80x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %26, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, 
%17, %18, %19, %20, %21}," + " %22," + " %23," + " %24, %25," + " p, %27, %28;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[22]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %29, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21}," + "{%22, %23, %24, %25}," + " %26," + " %27, %28," + " p, %30, %31;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, 
float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x88x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x88x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n88k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, 
%29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x88x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %30, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + " %26," + " %27," + " %28, %29," + " p, %31, %32;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = 
uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[26]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %33, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25}," + "{%26, %27, %28, %29}," + " %30," + " %31, %32," + " p, %34, %35;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, 
%13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x104x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x104x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n104k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), 
"+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x104x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %32, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + " %28," + " %29," + " %30, %31," + " p, %33, %34;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[28]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & 
d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %35, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27}," + "{%28, %29, %30, %31}," + " %32," + " %33, %34," + " p, %36, %37;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " 
%48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x112x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x112x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n112k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), 
"+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x112x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %34, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + " %30," + " %31," + " %32, %33," + " p, %35, %36;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[30]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & 
d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %37, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29}," + "{%30, %31, %32, %33}," + " %34," + " %35, %36," + " p, %38, %39;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, 
%19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x120x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x120x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n120k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), 
"+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x120x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %38, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + " %34," + " %35," + " %36, %37," + " p, %39, %40;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F16+=E5M2*E5M2 +template < + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[34]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %41, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33}," + "{%34, %35, %36, %37}," + " %38," + " %39, %40," + " p, %42, %43;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + 
float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %72, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + " %68," + " %69," + " %70, %71," + " p, %73, %74;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x136x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x136x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[68]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, 
float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %75, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n136k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67}," + "{%68, %69, %70, %71}," + " %72," + " %73, %74," + " p, %76, %77;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x136x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %40, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + " %36," + " %37," + " %38, %39," + " p, %41, %42;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[36]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %43, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35}," + "{%36, %37, %38, %39}," + " %40," + " %41, %42," + " p, %44, %45;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x144x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %76, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + " %72," + " %73," + " %74, %75," + " p, %77, %78;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E5M2_SS_TN without 
CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x144x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x144x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[72]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %79, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n144k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71}," + "{%72, %73, %74, %75}," + " %76," + " %77, %78," + " p, %80, %81;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use 
SM90::GMMA::SPARSE::GMMA_64x144x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %42, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + " %38," + " %39," + " %40, %41," + " p, %43, %44;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[38]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, 
uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %45, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37}," + "{%38, %39, %40, %41}," + " %42," + " %43, %44," + " p, %46, %47;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + 
cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %80, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + " %76," + " %77," + " %78, %79," + " p, %81, %82;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x152x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x152x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[76]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + 
uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %83, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n152k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75}," + "{%76, %77, %78, %79}," + " %80," + " %81, %82," + " p, %84, %85;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x152x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %44, 0;\n" + 
"wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + " %40," + " %41," + " %42, %43," + " p, %45, %46;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[40]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %47, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39}," + "{%40, %41, %42, %43}," + " %44," + " %45, %46," + " p, %48, %49;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + 
"l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %84, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + " %80," + " %81," + " %82, %83," + " p, %85, %86;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), 
"+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x160x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x160x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[80]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %87, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n160k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79}," + "{%80, %81, %82, %83}," + " %84," + " %85, %86," + " p, %88, %89;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), 
"+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x160x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %46, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + " %42," + " %43," + " %44, %45," + " p, %47, %48;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
+// SPARSE GMMA 64x168x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[42]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %49, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41}," + "{%42, %43, %44, %45}," + " %46," + " %47, %48," + " p, %50, %51;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & 
d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %88, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + " %84," + " %85," + " %86, %87," + " p, %89, %90;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x168x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x168x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[84]; + + CUTE_HOST_DEVICE static void + fma(uint32_t 
const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %91, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n168k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83}," + "{%84, %85, %86, %87}," + " %88," + " %89, %90," + " p, %92, %93;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x168x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E5M2*E5M2 
+template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %48, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + " %44," + " %45," + " %46, %47," + " p, %49, %50;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[44]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & 
d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %51, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43}," + "{%44, %45, %46, %47}," + " %48," + " %49, %50," + " p, %52, %53;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float 
& d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %92, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + " %88," + " %89," + " %90, %91," + " p, %93, %94;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x176x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x176x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[88]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float 
& d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %95, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n176k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87}," + "{%88, %89, %90, %91}," + " %92," + " %93, %94," + " p, %96, %97;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x176x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, 
uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %50, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + " %46," + " %47," + " %48, %49," + " p, %51, %52;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[46]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, 
uint32_t & d45, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %53, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45}," + "{%46, %47, %48, %49}," + " %50," + " %51, %52," + " p, %54, %55;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, 
desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %96, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + " %92," + " %93," + " %94, %95," + " p, %97, %98;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x184x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x184x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[92]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + float & d00, float & d01, float & d02, float & d03, + float & d04, float & d05, float & d06, float & d07, + float & d08, float & d09, float & d10, float & d11, + float & d12, float & d13, float & d14, float & d15, + float & d16, float & d17, float & d18, float & d19, + float & d20, float & d21, float & d22, float & d23, + float & d24, float & d25, float & d26, float & d27, + float & d28, float & d29, float & d30, float & d31, + float & d32, float & d33, float & d34, float & d35, + float & d36, float & d37, float & d38, float & d39, + float & d40, float & d41, float & d42, float & d43, + float & d44, float & d45, float & d46, float & d47, + float & d48, float & d49, float & d50, float & d51, + float & d52, float & d53, float & d54, float & d55, + float & d56, float & d57, float & d58, float & d59, + 
float & d60, float & d61, float & d62, float & d63, + float & d64, float & d65, float & d66, float & d67, + float & d68, float & d69, float & d70, float & d71, + float & d72, float & d73, float & d74, float & d75, + float & d76, float & d77, float & d78, float & d79, + float & d80, float & d81, float & d82, float & d83, + float & d84, float & d85, float & d86, float & d87, + float & d88, float & d89, float & d90, float & d91, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %99, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n184k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91}," + "{%92, %93, %94, %95}," + " %96," + " %97, %98," + " p, %100, %101;\n" + "}\n" + : "+f"(d00), "+f"(d01), "+f"(d02), "+f"(d03), + "+f"(d04), "+f"(d05), "+f"(d06), "+f"(d07), + "+f"(d08), "+f"(d09), "+f"(d10), "+f"(d11), + "+f"(d12), "+f"(d13), "+f"(d14), "+f"(d15), + "+f"(d16), "+f"(d17), "+f"(d18), "+f"(d19), + "+f"(d20), "+f"(d21), "+f"(d22), "+f"(d23), + "+f"(d24), "+f"(d25), "+f"(d26), "+f"(d27), + "+f"(d28), "+f"(d29), "+f"(d30), "+f"(d31), + "+f"(d32), "+f"(d33), "+f"(d34), "+f"(d35), + "+f"(d36), "+f"(d37), "+f"(d38), "+f"(d39), + "+f"(d40), "+f"(d41), "+f"(d42), "+f"(d43), + "+f"(d44), "+f"(d45), "+f"(d46), "+f"(d47), + "+f"(d48), "+f"(d49), "+f"(d50), "+f"(d51), + "+f"(d52), "+f"(d53), "+f"(d54), "+f"(d55), + "+f"(d56), "+f"(d57), "+f"(d58), "+f"(d59), + "+f"(d60), "+f"(d61), "+f"(d62), "+f"(d63), + "+f"(d64), "+f"(d65), "+f"(d66), "+f"(d67), + "+f"(d68), "+f"(d69), "+f"(d70), "+f"(d71), + "+f"(d72), "+f"(d73), "+f"(d74), "+f"(d75), + "+f"(d76), "+f"(d77), "+f"(d78), "+f"(d79), + "+f"(d80), "+f"(d81), "+f"(d82), "+f"(d83), + "+f"(d84), "+f"(d85), "+f"(d86), "+f"(d87), + "+f"(d88), "+f"(d89), "+f"(d90), "+f"(d91) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x184x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, 
uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %54, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + " %50," + " %51," + " %52, %53," + " p, %55, %56;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[50]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & 
d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %57, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49}," + "{%50, %51, %52, %53}," + " %54," + " %55, %56," + " p, %58, %59;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & 
d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %104, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + " %100," + " %101," + " %102, %103," + " p, %105, %106;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x200x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x200x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[100]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & 
d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %107, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n200k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99}," + "{%100, %101, %102, %103}," + " %104," + " %105, %106," + " p, %108, %109;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), 
"n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x200x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %56, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + " %52," + " %53," + " %54, %55," + " p, %57, %58;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + 
GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[52]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %59, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51}," + "{%52, %53, %54, %55}," + " %56," + " %57, %58," + " p, %60, %61;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float 
& d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %108, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + " %104," + " %105," + " %106, %107," + " p, %109, %110;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), 
"+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x208x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x208x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[104]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %111, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n208k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103}," + 
"{%104, %105, %106, %107}," + " %108," + " %109, %110," + " p, %112, %113;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x208x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %58, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, 
%4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + " %54," + " %55," + " %56, %57," + " p, %59, %60;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[54]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %61, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53}," + "{%54, %55, %56, %57}," + " %58," + " %59, %60," + " p, %62, %63;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), 
"+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %112, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, 
%28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + " %108," + " %109," + " %110, %111," + " p, %113, %114;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x216x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x216x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[108]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & 
d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %115, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n216k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107}," + "{%108, %109, %110, %111}," + " %112," + " %113, %114," + " p, %116, %117;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x216x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); 
+#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %60, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + " %56," + " %57," + " %58, %59," + " p, %61, %62;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = 
void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[56]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %63, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55}," + "{%56, %57, %58, %59}," + " %60," + " %61, %62," + " p, %64, %65;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + 
float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %116, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + " %112," + " %113," + " %114, %115," + " p, %117, %118;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), 
"+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x224x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x224x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[112]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %119, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n224k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, 
%34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111}," + "{%112, %113, %114, %115}," + " %116," + " %117, %118," + " p, %120, %121;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x224x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, 
uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %62, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + " %58," + " %59," + " %60, %61," + " p, %63, %64;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[58]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + 
uint32_t & d56, uint32_t & d57, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %65, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57}," + "{%58, %59, %60, %61}," + " %62," + " %63, %64," + " p, %66, %67;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, 
float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %120, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + " %116," + " %117," + " %118, %119," + " p, %121, %122;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x232x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + 
GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x232x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[116]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %123, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n232k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115}," + "{%116, %117, %118, %119}," + " %120," + " %121, %122," + " p, %124, %125;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), 
"+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x232x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %64, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, 
%42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + " %60," + " %61," + " %62, %63," + " p, %65, %66;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[60]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %67, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59}," + "{%60, %61, %62, %63}," + " %64," + " %65, %66," + " p, %68, %69;\n" + "}\n" + : "+r"(d00), 
"+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" 
+ ".reg .pred p;\n" + "setp.ne.b32 p, %124, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + " %120," + " %121," + " %122, %123," + " p, %125, %126;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x240x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x240x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[120]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & 
d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %127, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n240k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119}," + "{%120, %121, %122, %123}," + " %124," + " %125, %126," + " p, %128, %129;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), 
"+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x240x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %66, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + " %62," + " %63," + " %64, %65," + " p, %67, %68;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + 
"+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), "+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F16+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F16E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = uint32_t[62]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a00, uint32_t const& a01, uint32_t const& a02, uint32_t const& a03, + uint64_t const& desc_b, + uint32_t & d00, uint32_t & d01, uint32_t & d02, uint32_t & d03, + uint32_t & d04, uint32_t & d05, uint32_t & d06, uint32_t & d07, + uint32_t & d08, uint32_t & d09, uint32_t & d10, uint32_t & d11, + uint32_t & d12, uint32_t & d13, uint32_t & d14, uint32_t & d15, + uint32_t & d16, uint32_t & d17, uint32_t & d18, uint32_t & d19, + uint32_t & d20, uint32_t & d21, uint32_t & d22, uint32_t & d23, + uint32_t & d24, uint32_t & d25, uint32_t & d26, uint32_t & d27, + uint32_t & d28, uint32_t & d29, uint32_t & d30, uint32_t & d31, + uint32_t & d32, uint32_t & d33, uint32_t & d34, uint32_t & d35, + uint32_t & d36, uint32_t & d37, uint32_t & d38, uint32_t & d39, + uint32_t & d40, uint32_t & d41, uint32_t & d42, uint32_t & d43, + uint32_t & d44, uint32_t & d45, uint32_t & d46, uint32_t & d47, + uint32_t & d48, uint32_t & d49, uint32_t & d50, uint32_t & d51, + uint32_t & d52, uint32_t & d53, uint32_t & d54, uint32_t & d55, + uint32_t & d56, uint32_t & d57, uint32_t & d58, uint32_t & d59, + uint32_t & d60, uint32_t & d61, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %69, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f16.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61}," + "{%62, %63, %64, %65}," + " %66," + " %67, %68," + " p, %70, %71;\n" + "}\n" + : "+r"(d00), "+r"(d01), "+r"(d02), "+r"(d03), + "+r"(d04), "+r"(d05), "+r"(d06), "+r"(d07), + "+r"(d08), "+r"(d09), "+r"(d10), "+r"(d11), + "+r"(d12), "+r"(d13), "+r"(d14), "+r"(d15), + "+r"(d16), "+r"(d17), "+r"(d18), "+r"(d19), + "+r"(d20), "+r"(d21), "+r"(d22), "+r"(d23), + "+r"(d24), "+r"(d25), "+r"(d26), "+r"(d27), + "+r"(d28), "+r"(d29), "+r"(d30), "+r"(d31), + "+r"(d32), "+r"(d33), 
"+r"(d34), "+r"(d35), + "+r"(d36), "+r"(d37), "+r"(d38), "+r"(d39), + "+r"(d40), "+r"(d41), "+r"(d42), "+r"(d43), + "+r"(d44), "+r"(d45), "+r"(d46), "+r"(d47), + "+r"(d48), "+r"(d49), "+r"(d50), "+r"(d51), + "+r"(d52), "+r"(d53), "+r"(d54), "+r"(d55), + "+r"(d56), "+r"(d57), "+r"(d58), "+r"(d59), + "+r"(d60), "+r"(d61) + : "r"(a00), "r"(a01), "r"(a02), "r"(a03), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F16E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E5M2E5M2_SS_TN +{ + using DRegisters = void; + using ARegisters = uint64_t[1]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint64_t const& desc_a, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_smem_smem(__LINE__, desc_a, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %128, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, 
%33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + " %124," + " %125," + " %126, %127," + " p, %129, %130;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), "+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "l"(desc_a), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E5M2_SS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SPARSE GMMA 64x248x64 TN F32+=E5M2*E5M2 +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One, + GMMA::SparseSel spsel = GMMA::SparseSel::Zero +> +struct GMMA_64x248x64_F32E5M2E5M2_RS_TN +{ + using DRegisters = void; + using ARegisters = uint32_t[4]; + using ERegisters = uint32_t[1]; + using BRegisters = uint64_t[1]; + using CRegisters = float[124]; + + CUTE_HOST_DEVICE static void + fma(uint32_t const& a000, uint32_t const& a001, uint32_t const& a002, uint32_t const& a003, + uint64_t const& desc_b, + float & d000, float & d001, float & d002, float & d003, + float & d004, float & d005, float & d006, float & d007, + float & d008, float & d009, float & d010, float & d011, + float & d012, float & d013, float & d014, float & d015, + float & d016, float & d017, float & d018, float & d019, + float & d020, float & d021, float & d022, float & d023, + float & d024, float & d025, float & d026, float & d027, + float & d028, float & 
d029, float & d030, float & d031, + float & d032, float & d033, float & d034, float & d035, + float & d036, float & d037, float & d038, float & d039, + float & d040, float & d041, float & d042, float & d043, + float & d044, float & d045, float & d046, float & d047, + float & d048, float & d049, float & d050, float & d051, + float & d052, float & d053, float & d054, float & d055, + float & d056, float & d057, float & d058, float & d059, + float & d060, float & d061, float & d062, float & d063, + float & d064, float & d065, float & d066, float & d067, + float & d068, float & d069, float & d070, float & d071, + float & d072, float & d073, float & d074, float & d075, + float & d076, float & d077, float & d078, float & d079, + float & d080, float & d081, float & d082, float & d083, + float & d084, float & d085, float & d086, float & d087, + float & d088, float & d089, float & d090, float & d091, + float & d092, float & d093, float & d094, float & d095, + float & d096, float & d097, float & d098, float & d099, + float & d100, float & d101, float & d102, float & d103, + float & d104, float & d105, float & d106, float & d107, + float & d108, float & d109, float & d110, float & d111, + float & d112, float & d113, float & d114, float & d115, + float & d116, float & d117, float & d118, float & d119, + float & d120, float & d121, float & d122, float & d123, + uint32_t const& e, + GMMA::ScaleOut const scale_D = GMMA::ScaleOut::One) + { +#if defined(CUTE_ARCH_MMA_SM90A_ENABLED) + cutlass::arch::synclog_emit_wgmma_reg_smem(__LINE__, desc_b); + asm volatile( + "{\n" + ".reg .pred p;\n" + "setp.ne.b32 p, %131, 0;\n" + "wgmma.mma_async.sp.sync.aligned.m64n248k64.f32.e5m2.e5m2 " + "{%0, %1, %2, %3, %4, %5, %6, %7, " + " %8, %9, %10, %11, %12, %13, %14, %15, " + " %16, %17, %18, %19, %20, %21, %22, %23, " + " %24, %25, %26, %27, %28, %29, %30, %31, " + " %32, %33, %34, %35, %36, %37, %38, %39, " + " %40, %41, %42, %43, %44, %45, %46, %47, " + " %48, %49, %50, %51, %52, %53, %54, %55, " + " %56, %57, %58, %59, %60, %61, %62, %63, " + " %64, %65, %66, %67, %68, %69, %70, %71, " + " %72, %73, %74, %75, %76, %77, %78, %79, " + " %80, %81, %82, %83, %84, %85, %86, %87, " + " %88, %89, %90, %91, %92, %93, %94, %95, " + " %96, %97, %98, %99, %100, %101, %102, %103, " + " %104, %105, %106, %107, %108, %109, %110, %111, " + " %112, %113, %114, %115, %116, %117, %118, %119, " + " %120, %121, %122, %123}," + "{%124, %125, %126, %127}," + " %128," + " %129, %130," + " p, %132, %133;\n" + "}\n" + : "+f"(d000), "+f"(d001), "+f"(d002), "+f"(d003), + "+f"(d004), "+f"(d005), "+f"(d006), "+f"(d007), + "+f"(d008), "+f"(d009), "+f"(d010), "+f"(d011), + "+f"(d012), "+f"(d013), "+f"(d014), "+f"(d015), + "+f"(d016), "+f"(d017), "+f"(d018), "+f"(d019), + "+f"(d020), "+f"(d021), "+f"(d022), "+f"(d023), + "+f"(d024), "+f"(d025), "+f"(d026), "+f"(d027), + "+f"(d028), "+f"(d029), "+f"(d030), "+f"(d031), + "+f"(d032), "+f"(d033), "+f"(d034), "+f"(d035), + "+f"(d036), "+f"(d037), "+f"(d038), "+f"(d039), + "+f"(d040), "+f"(d041), "+f"(d042), "+f"(d043), + "+f"(d044), "+f"(d045), "+f"(d046), "+f"(d047), + "+f"(d048), "+f"(d049), "+f"(d050), "+f"(d051), + "+f"(d052), "+f"(d053), "+f"(d054), "+f"(d055), + "+f"(d056), "+f"(d057), "+f"(d058), "+f"(d059), + "+f"(d060), "+f"(d061), "+f"(d062), "+f"(d063), + "+f"(d064), "+f"(d065), "+f"(d066), "+f"(d067), + "+f"(d068), "+f"(d069), "+f"(d070), "+f"(d071), + "+f"(d072), "+f"(d073), "+f"(d074), "+f"(d075), + "+f"(d076), "+f"(d077), "+f"(d078), "+f"(d079), + "+f"(d080), "+f"(d081), "+f"(d082), 
"+f"(d083), + "+f"(d084), "+f"(d085), "+f"(d086), "+f"(d087), + "+f"(d088), "+f"(d089), "+f"(d090), "+f"(d091), + "+f"(d092), "+f"(d093), "+f"(d094), "+f"(d095), + "+f"(d096), "+f"(d097), "+f"(d098), "+f"(d099), + "+f"(d100), "+f"(d101), "+f"(d102), "+f"(d103), + "+f"(d104), "+f"(d105), "+f"(d106), "+f"(d107), + "+f"(d108), "+f"(d109), "+f"(d110), "+f"(d111), + "+f"(d112), "+f"(d113), "+f"(d114), "+f"(d115), + "+f"(d116), "+f"(d117), "+f"(d118), "+f"(d119), + "+f"(d120), "+f"(d121), "+f"(d122), "+f"(d123) + : "r"(a000), "r"(a001), "r"(a002), "r"(a003), + "l"(desc_b), + "r"(e), "n"(int32_t(spsel)), + "r"(int32_t(scale_D)), "n"(int32_t(scaleA)), "n"(int32_t(scaleB))); +#else + CUTE_INVALID_CONTROL_PATH("Attempting to use SM90::GMMA::SPARSE::GMMA_64x248x64_F32E5M2E5M2_RS_TN without CUTE_ARCH_MMA_SM90A_ENABLED"); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SM90::GMMA::SPARSE + +} // namespace cute diff --git a/include/cute/arch/util.hpp b/include/cute/arch/util.hpp new file mode 100644 index 0000000000..3749a9c255 --- /dev/null +++ b/include/cute/arch/util.hpp @@ -0,0 +1,320 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include + +#if defined(__clang__) && defined(__CUDA__) + // __cvta_generic_to_shared was added in Clang 14: https://reviews.llvm.org/D111665 + #if __clang_major__ >= 14 + #define CUTE_CLANG_SUPPORTS_CVTA_GENERIC_TO_SHARED 1 + #endif + + // __nvvm_get_smem_pointer added in Clang 14: https://reviews.llvm.org/D111665 + // ... 
but will not work on Windows until Clang 15: https://reviews.llvm.org/D122897 + #if (!defined(_WIN32) && __clang_major__ >= 14) || __clang_major__ >= 15 + #define CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER 1 + #endif +#endif + +#if defined(__NVCC__) || defined(__CUDACC_RTC__) + // __cvta_generic_to_shared added in CUDA 11+ + #if __CUDACC_VER_MAJOR__ >= 11 + #define CUTE_NVCC_SUPPORTS_CVTA_GENERIC_TO_SHARED 1 + #endif + + // __nvvm_get_smem_pointer added in CUDA 10.2 + #if __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2 + #define CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER 1 + #endif +#endif + +#if CUTE_NVCC_SUPPORTS_CVTA_GENERIC_TO_SHARED || CUTE_CLANG_SUPPORTS_CVTA_GENERIC_TO_SHARED + #define CUTE_CVTA_GENERIC_TO_SHARED_SUPPORTED 1 +#endif + +#if !defined(CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED) && CUTE_CVTA_GENERIC_TO_SHARED_SUPPORTED && defined(__CUDA_ARCH__) + #define CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED 1 +#endif + +#if CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER || CUTE_CLANG_SUPPORTS_NVVM_GET_SMEM_POINTER + #define CUTE_NVVM_GET_SMEM_POINTER_SUPPORTED 1 +#endif + +#if !defined(CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED) && CUTE_NVVM_GET_SMEM_POINTER_SUPPORTED && defined(__CUDA_ARCH__) + #define CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED 1 +#endif + +// Clang 14+ provides a declaration of __nvvm_get_smem_pointer, so we only need +// to provide one for NVCC +#if CUTE_NVCC_SUPPORTS_NVVM_GET_SMEM_POINTER + extern "C" { + // This NVVM intrinsic is subject to change in future versions of CUDA. + // Clients should not call it directly. + CUTE_DEVICE uint32_t __nvvm_get_smem_pointer(void*); + } +#endif + +namespace cute +{ + +/// CUTE helper to cast SMEM pointer to unsigned +CUTE_DEVICE +uint32_t +cast_smem_ptr_to_uint(void const* const ptr) +{ +// We prefer to use the new CVTA intrinsics if they are available, otherwise we will fall back to +// the previous internal intrinsics if they are available. +#if CUTE_CVTA_GENERIC_TO_SHARED_ACTIVATED + // + // This NVVM intrinsic converts an address in shared memory to a plain + // unsigned integer. This is necessary to pass to shared memory instructions + // in inline PTX. + // + // In CUDA 11 and beyond, this replaces __nvvm_get_smem_pointer() [only available in 10.2]. + // + //__device__ size_t __cvta_generic_to_shared(void* ptr); + + /// CUTE helper to get SMEM pointer + return static_cast(__cvta_generic_to_shared(ptr)); + +#elif CUTE_NVVM_GET_SMEM_POINTER_ACTIVATED + + return __nvvm_get_smem_pointer(ptr); + +#elif defined(__CUDA_ARCH__) + + uint32_t smem_ptr; + + asm( + "{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n" + : "=r"(smem_ptr) : "l"(ptr)); + + return smem_ptr; + +#else + + + (void) ptr; + printf("ERROR: cast_smem_ptr_to_uint not supported but used.\n"); + return 0; + +#endif +} + +namespace detail { + +// +// Wrapper for MMAOp::fma +// + +template +struct CallFMA { + template + CUTE_HOST_DEVICE constexpr void + operator()(Args&&... args) const { + return MmaOp::fma(static_cast(args)...); + } +}; + +// +// Wrapper for CopyOp::copy +// + +template +struct CallCOPY { + template + CUTE_HOST_DEVICE constexpr void + operator()(Args&&... 
args) const { + return CopyOp::copy(static_cast(args)...); + } +}; + +// +// Utility for exploding pointers/arrays/tensors into functions +// + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrA&& a, int_sequence) +{ + return fn(a[I]...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrS&& s, int_sequence, + PtrD&& d, int_sequence) +{ + return fn(s[Is]..., d[Id]...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrA&& a, int_sequence, + PtrB&& b, int_sequence, + PtrC&& c, int_sequence) +{ + return fn(a[Ia]..., b[Ib]..., c[Ic]...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrD&& d, int_sequence, + PtrA&& a, int_sequence, + PtrB&& b, int_sequence, + PtrC&& c, int_sequence) +{ + return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrD&& d, int_sequence, + PtrA&& a, int_sequence, + PtrB&& b, int_sequence, + PtrC&& c, int_sequence, + PtrE&& e, int_sequence) +{ + return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., e[Ie]...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrD&& d, int_sequence, + PtrA&& a, int_sequence, + PtrB&& b, int_sequence, + PtrC&& c, int_sequence, + PtrE&& e, int_sequence, + PtrF&& f, int_sequence) +{ + return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., e[Ie]..., f[If]...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode(Fn fn, + PtrD&& d, int_sequence, + PtrA&& a, int_sequence, + PtrB&& b, int_sequence, + PtrC&& c, int_sequence, + PtrE&& e, int_sequence, + PtrF&& f, int_sequence, + PtrG&& g, int_sequence) +{ + return fn(d[Id]..., a[Ia]..., b[Ib]..., c[Ic]..., e[Ie]..., f[If]..., g[Ig]...); +} + +// +// Utility for exploding tuples into functions +// + +template +CUTE_HOST_DEVICE constexpr +void +explode_tuple(Fn fn, + TupleA&& a, int_sequence) +{ + return fn(get(a)...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode_tuple(Fn fn, + TupleA&& a, int_sequence, + TupleB&& b, int_sequence) +{ + return fn(get(a)..., get(b)...); +} + +template +CUTE_HOST_DEVICE constexpr +void +explode_tuple(Fn fn, + TupleA&& a, int_sequence, + TupleB&& b, int_sequence, + TupleC&& c, int_sequence) +{ + return fn(get(a)..., get(b)..., get(c)...); +} + +} // end namespace detail + +} // end namespace cute diff --git a/include/cute/atom/copy_atom.hpp b/include/cute/atom/copy_atom.hpp new file mode 100644 index 0000000000..75b7aa4de6 --- /dev/null +++ b/include/cute/atom/copy_atom.hpp @@ -0,0 +1,764 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::Tensor +#include // cute::__CUTE_REQUIRES +#include // cute::is_tuple +#include // cute::is_constant, cute::is_integral +#include // cute::Copy_Traits +#include // cute::TiledMMA + +namespace cute +{ + +template +struct Copy_Atom; + +template +struct Copy_Atom : Copy_Atom, CopyInternalType> +{}; + +template +struct Copy_Atom, CopyInternalType> + : Copy_Traits +{ + using Traits = Copy_Traits; + + // Bit and Thr layouts from the Copy_Traits + using ThrID = typename Traits::ThrID; + using BitLayoutSrc = typename Traits::SrcLayout; + using BitLayoutDst = typename Traits::DstLayout; + using BitLayoutRef = typename Traits::RefLayout; + + using ValType = CopyInternalType; + + using ValLayoutSrc = decltype(recast_layout(BitLayoutSrc{})); + using ValLayoutDst = decltype(recast_layout(BitLayoutDst{})); + using ValLayoutRef = decltype(recast_layout(BitLayoutRef{})); + + CUTE_STATIC_ASSERT_V(size<0>(ValLayoutSrc{}) == size(ThrID{}), "CopyOperation is not valid for Src of ValType."); + CUTE_STATIC_ASSERT_V(size<0>(ValLayoutDst{}) == size(ThrID{}), "CopyOperation is not valid for Dst of ValType."); + CUTE_STATIC_ASSERT_V(size<0>(ValLayoutRef{}) == size(ThrID{}), "CopyOperation is not valid for Ref of ValType."); + + static constexpr int NumValSrc = size<1>(ValLayoutSrc{}); + static constexpr int NumValDst = size<1>(ValLayoutDst{}); + + // Additional Trait parameters/transformations + template + CUTE_HOST_DEVICE + auto + with(TraitsArgs&&... args) const { + auto traits = Traits::with(static_cast(args)...); + return Copy_Atom{traits}; + } + + // + // Tensor call interfaces + // + + // Check and call instruction, or recurse + template + CUTE_HOST_DEVICE + void + call(Tensor const& src, + Tensor & dst) const + { + static_assert(SLayout::rank == 1, "Expected rank-1 src tensor"); + static_assert(DLayout::rank == 1, "Expected rank-1 dst tensor"); + + if constexpr (is_constant::value || + is_constant::value) { + // Dispatch to unpack to execute instruction + return copy_unpack(static_cast(*this), src, dst); + } else if constexpr (is_tuple::value && + is_tuple::value) { + // If the size of the src/dst doesn't match the instruction, + // recurse this rank-1 layout by peeling off the mode + // ((A,B,C,...)) -> (A,B,C,...) 
+ return copy(*this, tensor<0>(src), tensor<0>(dst)); + } else { + static_assert(dependent_false, + "CopyAtom: Src/Dst partitioning does not match the instruction requirement."); + } + } + + // Accept mutable temporaries + template + CUTE_HOST_DEVICE + void + call(Tensor const& src, + Tensor && dst) const + { + return call(src, dst); + } +}; + +// +// A tiling of copy atoms +// + +template +struct ThrCopy; + +template coord [Need not be 2D...] + class ShapeTiler_MN> // coord space +struct TiledCopy : Copy_Atom +{ + // Layout information from the CopyAtom + using AtomThrID = typename Copy_Atom::ThrID; // thrid -> thr_idx + using AtomLayoutSrc = typename Copy_Atom::ValLayoutSrc; // (thr,val) -> offset + using AtomLayoutDst = typename Copy_Atom::ValLayoutDst; // (thr,val) -> offset + using AtomLayoutRef = typename Copy_Atom::ValLayoutRef; // (thr,val) -> offset + + using AtomNumThr = decltype(size<0>(AtomLayoutRef{})); + using AtomNumVal = decltype(size<1>(AtomLayoutRef{})); + + // Layout information for the TiledCopy + using Tiler_MN = ShapeTiler_MN; + using TiledLayout_TV = LayoutCopy_TV; + using TiledNumThr = decltype(size<0>(TiledLayout_TV{})); + using TiledNumVal = decltype(size<1>(TiledLayout_TV{})); + + CUTE_STATIC_ASSERT_V(TiledNumThr{} % AtomNumThr{} == Int<0>{}, "TiledCopy uses too few thrs for selected CopyAtom"); + CUTE_STATIC_ASSERT_V(TiledNumVal{} % AtomNumVal{} == Int<0>{}, "TiledCopy uses too few vals for selected CopyAtom"); + + // Tile a tensor or a layout from shape + // (M,N,...) + // to shape + // ((ThrV,ThrX),FrgV,(RestM,RestN,...)) + // where + // ThrV: The threads local to a COPY_ATOM Src. + // ThrX: The threads tiled across COPY_ATOMs Src. + // FrgV: The values local to a COPY_ATOM Src. + // RestM: The values tiled in M. + // RestN: The values tiled in N. + template + CUTE_HOST_DEVICE constexpr static + auto + tidfrg_S(STensor&& stensor) + { + CUTE_STATIC_ASSERT_V(rank(stensor) >= rank(Tiler_MN{}), "Rank of tensor to be partitioned too small."); + + // Tile the stensor and compute the (src-thr, src-val) -> (ref-thr, ref-val) layout + return tile2thrfrg(zipped_divide(stensor,Tiler_MN{}), right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{})); + } + + // Tile a tensor or a layout from shape + // (M,N,...) + // to shape + // ((ThrV,ThrX),FrgV,(RestM,RestN,...)) + // where + // ThrV: The threads local to a COPY_ATOM Dst. + // ThrX: The threads tiled across COPY_ATOMs Dst. + // FrgV: The values local to a COPY_ATOM Dst. + // RestM: The values tiled in M. + // RestN: The values tiled in N. 
+ template + CUTE_HOST_DEVICE constexpr static + auto + tidfrg_D(DTensor&& dtensor) + { + CUTE_STATIC_ASSERT_V(rank(dtensor) >= rank(Tiler_MN{}), "Rank of tensor to be partitioned too small."); + + // Tile the dtensor and compute the (dst-thr, dst-val) -> (ref-thr, ref-val) layout + return tile2thrfrg(zipped_divide(dtensor,Tiler_MN{}), right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{})); + } + + // Tile a tensor or a layout from shape + // ((TileM,TileN,...), (RestM,RestN,...)) + // to shape + // ((ThrV,ThrX),FrgV,(RestM,RestN,...)) + template + CUTE_HOST_DEVICE constexpr static + auto + tile2thrfrg(Tensor&& tensor, Ref2TrgLayout const& ref2trg) + { + // Take the thrs/vals that the atom is interested in + // NOTE: Assumes the AtomNumThr are contiguous and identity within TiledThrID + auto atom_layout_TV = zipped_divide(TiledLayout_TV{}, make_shape(AtomNumThr{}, AtomNumVal{})); + // ((atom_tid,atom_val),(rest_tid,rest_val)) -> (m,n) + + // Transform to the trg layout + auto trg_layout_TV = atom_layout_TV.compose(ref2trg, _); + // ((trg_tid,trg_val),(rest_tid,rest_val)) -> (m,n) + + // Transform the thrs mode from thrid to thr_idx + // NOTE: Assumes the AtomNumThr are contiguous and identity within TiledThrID + auto thrval2mn = coalesce(zip(trg_layout_TV), Shape<_1,Shape<_1,_1>>{}); + // ((trg_tid,rest_tid),(trg_val,rest_val)) -> (m,n) + + /// ================== + + // Transform the tile mode + auto tv_tensor = tensor.compose(thrval2mn, _); + // ((thrid,val),(RestM,RestN,...)) + + // Unfold and return + return tv_tensor(make_coord(_,_), _); + } + + // retile_S and retile_D assume they are working with the reference layout -- they are the same + template + CUTE_HOST_DEVICE constexpr static + auto + retile(Tensor&& tensor) + { + constexpr int R = remove_cvref_t::rank; + // Assert that AtomLayoutSrc|Dst is identity so we can skip the Ref transformation + + // Assume the first size<0>(tensor) elements are the first val_ids in TiledLayout_TV. 
+ // Then, we only need the shape+layout of those size<0>(tensor) elements in TiledLayout_TV + // and that shape is what we gather from the other modes of tensor + + auto V = size<0>(tensor); + + auto frg_layout_mn = upcast(right_inverse(TiledLayout_TV{}).with_shape(shape(Tiler_MN{}))); + // (m,n) -> v_idx -- The shape and order of the V inside of TiledLayout_TV + + auto frg_layout_v = zipped_divide(logical_product(make_layout(V), right_inverse(frg_layout_mn)), make_layout(AtomNumVal{})); + // (atom_vals,rest_vals) -> (v,m,n) + + /// ======= + + // Tile the tensor for TileFrg + auto t_tensor = zipped_divide(tensor, prepend(product_each(shape(frg_layout_mn)), V)); + // ((TileV,TileM,TileN,...),(1,RestM,RestN,...)) + + // Transform the tile mode + auto v_tensor = t_tensor.compose(frg_layout_v, _); + // ((atom_vals,rest_vals),(1,RM,RN,...)) + + // Unfold and return + return v_tensor(_, append(Int<0>{},_)); + } + + CUTE_HOST_DEVICE constexpr static + auto + get_layoutS_TV() + { + // (M,N) -> (M,N) + auto ref_S = make_layout(make_shape(shape(Tiler_MN{}), Int<1>{})); + // (thr_idx,val_idx) -> (M,N) + return tile2thrfrg(ref_S, right_inverse(AtomLayoutRef{}).compose(AtomLayoutSrc{}))(_,_,Int<0>{}); + } + + CUTE_HOST_DEVICE constexpr static + auto + get_layoutS_MN() + { + // (thr_idx,val_idx) -> (M,N) + auto layoutS_TV = get_layoutS_TV(); + // (M,K) -> (thr_idx,val_idx) + auto layoutS_MK = right_inverse(layoutS_TV).with_shape(shape(Tiler_MN{})); + + // athrid = (v,m,k) -> thr_idx + auto thrID_S = make_layout(size<0>(TiledLayout_TV{})); + + return cute::make_tuple(layoutS_MK, thrID_S); + } + + CUTE_HOST_DEVICE constexpr static + auto + get_layoutD_TV() + { + // (M,N) -> (M,N) + auto ref_D = make_layout(make_shape(shape(Tiler_MN{}), Int<1>{})); + // (thr_idx,val_idx) -> (M,N) + return tile2thrfrg(ref_D, right_inverse(AtomLayoutRef{}).compose(AtomLayoutDst{}))(_,_,Int<0>{}); + } + + CUTE_HOST_DEVICE constexpr static + auto + get_layoutD_MN() + { + // (thr_idx,val_idx) -> (M,N) + auto layoutD_TV = get_layoutD_TV(); + // (M,K) -> (thr_idx,val_idx) + auto layoutD_MK = right_inverse(layoutD_TV).with_shape(shape(Tiler_MN{})); + + // athrid = (v,m,k) -> thr_idx + auto thrID_D = make_layout(size<0>(TiledLayout_TV{})); + + return cute::make_tuple(layoutD_MK, thrID_D); + } + + template ::value)> + CUTE_HOST_DEVICE static + auto + get_slice(ThrIdx const& thr_idx) + { + return ThrCopy(thr_idx); + } + + template ::value)> + CUTE_HOST_DEVICE static + auto + get_thread_slice(ThrIdx const& thr_idx) + { + return get_slice(thr_idx); + } +}; + +template +struct ThrCopy +{ + ThrIdx thr_idx_; + + CUTE_HOST_DEVICE + ThrCopy(ThrIdx const& thr_idx) : thr_idx_(thr_idx) {} + + template + CUTE_HOST_DEVICE + auto + partition_S(STensor&& stensor) const { + //static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename TiledCopy::ValType), + // "Expected ValType for tiling SrcTensor."); + auto thr_tensor = make_tensor(static_cast(stensor).data(), TiledCopy::tidfrg_S(stensor.layout())); + return thr_tensor(thr_idx_, _, repeat>(_)); + } + + template + CUTE_HOST_DEVICE + auto + partition_D(DTensor&& dtensor) const { + //static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename TiledCopy::ValType), + // "Expected ValType for tiling DstTensor."); + auto thr_tensor = make_tensor(static_cast(dtensor).data(), TiledCopy::tidfrg_D(dtensor.layout())); + return thr_tensor(thr_idx_, _, repeat>(_)); + } + + template + CUTE_HOST_DEVICE static + auto + retile_S(STensor&& stensor) { + // 
static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename TiledCopy::ValType), + // "Expected ValType for tiling SrcTensor."); + return make_tensor(static_cast(stensor).data(), TiledCopy::retile(stensor.layout())); + } + + template + CUTE_HOST_DEVICE static + auto + retile_D(DTensor&& dtensor) { + // static_assert(sizeof(typename remove_cvref_t::value_type) == sizeof(typename TiledCopy::ValType), + // "Expected ValType for tiling DstTensor."); + return make_tensor(static_cast(dtensor).data(), TiledCopy::retile(dtensor.layout())); + } +}; + + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_impl(Copy_Atom const& atom, + LayoutCopy_TV const&, + Tiler const&) +{ + return TiledCopy, LayoutCopy_TV, Tiler>{atom}; +} + +// +// These tile the Copy_Atom as a whole +// + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_A(Copy_Atom const& copy_atom, + TiledMMA const& mma) +{ + return make_tiled_copy_impl(copy_atom, mma.get_layoutA_TV(), make_shape(tile_size<0>(mma),tile_size<2>(mma))); +} + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_B(Copy_Atom const& copy_atom, + TiledMMA const& mma) +{ + return make_tiled_copy_impl(copy_atom, mma.get_layoutB_TV(), make_shape(tile_size<1>(mma),tile_size<2>(mma))); +} + +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_C(Copy_Atom const& copy_atom, + TiledMMA const& mma) +{ + return make_tiled_copy_impl(copy_atom, mma.get_layoutC_TV(), make_shape(tile_size<0>(mma),tile_size<1>(mma))); +} + +// returns the smallest tiled copy that can retile LayoutC_TV +// for use with pipelined epilogues with subtiled stores +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_C_atom(Copy_Atom const& copy_atom, + TiledMMA const& mma) +{ + // Truncate the V-layout to just the Copy_Atom, keep the V-order + auto layoutC_TV = mma.get_layoutC_TV(); + auto copy_V = Int::NumValSrc>{}; + CUTE_STATIC_ASSERT_V(copy_V <= size<1>(layoutC_TV)); + auto layout_TV = composition(layoutC_TV, make_layout(make_shape(size<0>(layoutC_TV), copy_V))); + + // Recompute tiler and restride the TV layout for the new tiler + + // Tiler -- Find the active elements in the MMA tensor and generate a tiler to extract them + // Convert to the awkward by-mode tiler to preserve the modes of the tiled MMA + auto mma_tiler = make_shape(tile_size<0>(mma),tile_size<1>(mma)); + auto mma_zeros = repeat_like(mma_tiler, Int<0>{}); + + auto tiler = transform(make_seq{}, [&](auto i) { + return filter(composition(make_layout(mma_tiler, replace(mma_zeros, Int<1>{})), layout_TV)); + }); + + // Layout_TV -- Find the (tid,vid) -> tile coord transformation + // Apply the tiler to a reference and transform the codomain + // tile_coord -> mma_coord + auto tile2mma = composition(make_layout(mma_tiler), tiler); + + // (tid,vid) -> tile_coord + auto layout_tv = composition(left_inverse(tile2mma), layout_TV); + + return make_tiled_copy_impl(copy_atom, layout_tv, tiler); +} + +/** Produce a TiledCopy from logical thread and values layouts. + * The thread and value layouts map coordinates to thr_idx and val_idx. + * The product of these layouts is taken to produce the TV layout and the Tiler. + * Useful when threads and values need very specific mappings onto coordinates + * in the target tensors. 
+ */ +template > +CUTE_HOST_DEVICE +auto +make_tiled_copy(Copy_Atom const& copy_atom, + ThrLayout const& thr_layout = {}, // (m,n) -> thr_idx + ValLayout const& val_layout = {}) // (m,n) -> val_idx +{ + // Take the raked_products to compute the Layout_MN + // (M,N) -> (thr_idx, val_idx) + auto layout_mn = raked_product(thr_layout, val_layout); + // (thr_idx, val_idx) -> (M,N) + auto layout_tv = right_inverse(layout_mn).with_shape(make_shape(size(thr_layout), size(val_layout))); + // Tiler for extracting relevant elements + // (M,N) -> tensor coord + auto tiler = product_each(shape(layout_mn)); + +#if 0 + print("thr_layout: "); print(thr_layout); print("\n"); + print("val_layout: "); print(val_layout); print("\n"); + print("layout_mn : "); print(layout_mn); print("\n"); + print("layout_tv : "); print(layout_tv); print("\n"); + print("tiler : "); print(tiler); print("\n"); +#endif + + return make_tiled_copy_impl(copy_atom, layout_tv, tiler); +} + +/** Produce a TiledCopy from thread and value offset maps. + * The TV Layout maps threads and values to the codomain of the data_layout. + * It is verified that the intended codomain is valid within data_layout. + * Useful when threads and values don't care about owning specific coordinates, but + * care more about the vector-width and offsets between them. + */ +template +CUTE_HOST_DEVICE constexpr +auto +make_cotiled_copy(Copy_Atom const& copy_atom, + AtomTVLayout const& atom_tv_layout, // atom (thr,val) -> data addr + DataLayout const& data_layout) // coord -> data addr The target layout +{ + static_assert(is_static::value); + static_assert(is_static::value); + + // data addr -> data coord Append 1:0 so off-the-ends get the stride-0 + auto inv_data_layout = make_layout(left_inverse(data_layout), Layout<_1,_0>{}); + + // (tid,vid) -> data_coord + auto layout_tv_data = composition(inv_data_layout, atom_tv_layout); + + // Check validity + CUTE_STATIC_ASSERT_V(coalesce(composition(data_layout, layout<1>(layout_tv_data))) == coalesce(layout<1>(atom_tv_layout)), + "The memory pointed to by AtomTVLayout does not exist in the DataLayout."); + +#if 0 + if (thread0()) { + print("data_layout : "); print(data_layout); print("\n"); + print("atom_tv_layout : "); print(atom_tv_layout); print("\n"); + print("layout_tv_data : "); print(layout_tv_data); print("\n"); + } +#endif + + // + // Tiler -- Find the active elements in the DATA tensor and generate a tiler to extract them + // + + // Convert to the awkward by-mode tiler to preserve the modes of the tiled DATA + auto flat_data_shape = product_each(shape(data_layout)); + auto flat_data_zeros = repeat(Int<0>{}); + + auto tiler = transform(make_seq{}, [&](auto i) { + return filter(composition(make_layout(flat_data_shape, replace(flat_data_zeros, Int<1>{})), layout_tv_data)); + }); + + // + // Layout_TV -- Find the (tid,vid) -> tile coord transformation + // + + // Apply the tiler to a reference and transform the codomain + // tile_coord -> data_coord + auto tile2data = composition(make_layout(flat_data_shape), tiler); + + // (tid,vid) -> tile_coord + auto layout_tv = composition(left_inverse(tile2data), layout_tv_data); + +#if 0 + if (thread0()) { + print("tiler : "); print(tiler); print("\n"); + print("tile2data : "); print(tile2data); print("\n"); + print("layout_tv : "); print(layout_tv); print("\n"); + } +#endif + + return make_tiled_copy_impl(copy_atom, layout_tv, tiler); +} + +// Make a TiledCopy out of the copy_atom that matches the Src-Layout of tiled_copy +template +CUTE_HOST_DEVICE +auto 
+make_tiled_copy_S(Copy_Atom const& copy_atom, + TiledCopy const& tiled_copy) +{ + return make_tiled_copy_impl(copy_atom, tiled_copy.get_layoutS_TV(), typename TiledCopy::Tiler_MN{}); +} + +// Make a TiledCopy out of the copy_atom that matches the Dst-Layout of tiled_copy +template +CUTE_HOST_DEVICE +auto +make_tiled_copy_D(Copy_Atom const& copy_atom, + TiledCopy const& tiled_copy) +{ + return make_tiled_copy_impl(copy_atom, tiled_copy.get_layoutD_TV(), typename TiledCopy::Tiler_MN{}); +} + +// +// Size +// + +// The logical size of a TileCopy +template +CUTE_HOST_DEVICE constexpr +auto +tile_size(TiledCopy const&) +{ + return size(typename TiledCopy::Tiler_MN{}); +} + +// The number of threads involved in a TiledCopy +template +CUTE_HOST_DEVICE constexpr +auto +size(TiledCopy const&) +{ + return typename TiledCopy::TiledNumThr{}; +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE +void +print(Copy_Atom, T> const&) +{ + using Atom = Copy_Atom, T>; + print("Copy_Atom\n"); + print(" ThrID: "); print(typename Atom::ThrID{}); print("\n"); + print(" ValLayoutSrc: "); print(typename Atom::ValLayoutSrc{}); print("\n"); + print(" ValLayoutDst: "); print(typename Atom::ValLayoutDst{}); print("\n"); + print(" ValLayoutRef: "); print(typename Atom::ValLayoutRef{}); print("\n"); + print(" ValueType: "); print(sizeof_bits::value); print("b\n"); +} + +template +CUTE_HOST_DEVICE +void +print(TiledCopy const& copy, char const* pad = "") +{ + using Copy = TiledCopy; + print("TiledCopy\n"); + print(" Tiler_MN: "); print(typename Copy::Tiler_MN{}); print("\n"); + print(" TiledLayout_TV: "); print(typename Copy::TiledLayout_TV{}); print("\n"); + print(static_cast(copy)); +} + +template +CUTE_HOST_DEVICE +void +print(ThrCopy const& thr_copy) +{ + print("ThrCopy\n"); + print(" ThrIdx: "); print(thr_copy.thr_idx_); print("\n"); + print(TiledCopy{}); +} + +// TiledCopy to LaTeX TikZ +template +CUTE_HOST_DEVICE +auto +print_latex(TiledCopy const& copy, + TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string +{ + auto [layoutS_MN, thrID_S] = copy.get_layoutS_MN(); + auto [layoutD_MN, thrID_D] = copy.get_layoutD_MN(); + + print_latex_copy(layoutS_MN, thrID_S, + layoutD_MN, thrID_D); +} + +// MNK Copy Layout to LaTeX TikZ +template +CUTE_HOST_DEVICE +void +print_latex_copy(LayoutS const& S, ThrIDS const& TS, // (m,n) -> (tid,vid) and tid -> thr_idx + LayoutD const& D, ThrIDD const& TD, // (m,n) -> (tid,vid) and tid -> thr_idx + TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string +{ + CUTE_STATIC_ASSERT_V(rank(S) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(D) == Int<2>{}); + + assert(size<0>(S) == size<0>(D)); + assert(size<1>(S) == size<1>(D)); + + // Commented prints + printf("%% LayoutS: "); print(S); printf("\n"); + printf("%% ThrIDS : "); print(TS); printf("\n"); + printf("%% LayoutD: "); print(D); printf("\n"); + printf("%% ThrIDD : "); print(TD); printf("\n\n"); + + // Header + printf("\\documentclass[convert]{standalone}\n" + "\\usepackage{tikz}\n\n" + "\\begin{document}\n" + "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); + + // S starting at 0,0 + for (int i = 0; i < size<0>(S); ++i) { + for (int j = 0; j < size<1>(S); ++j) { + int thrid = S(i,j) % size(TS); + int val_idx = S(i,j) / size(TS); + int thr_idx = TS(thrid); + + printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color(thr_idx, val_idx), + i, j, + thr_idx, val_idx); + } + } + // Grid + 
printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", + 0, 0, int(size<0>(S)), int(size<1>(S))); + // S Labels + for (int i = 0, j = -1; i < size<0>(S); ++i) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i); + } + for (int i = -1, j = 0; j < size<1>(S); ++j) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j); + } + + // D starting at 0,size<1>(S)+3 + for (int i = 0; i < size<0>(D); ++i) { + for (int j = 0; j < size<1>(D); ++j) { + int thrid = D(i,j) % size(TD); + int val_idx = D(i,j) / size(TD); + int thr_idx = TD(thrid); + + printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color(thr_idx, val_idx), + i, j + size<1>(S) + 3, + thr_idx, val_idx); + } + } + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", + 0, int(size<1>(S)+3), int(size<0>(D)), int(size<1>(D)+size<1>(S)+3)); + // D Labels + for (int i = 0, j = size<1>(D); i < size<0>(D); ++i) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, i); + } + for (int i = -1, j = 0; j < size<1>(D); ++j) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j + size<1>(S) + 3, j); + } + + // Footer + printf("\\end{tikzpicture}\n" + "\\end{document}\n"); +} + +} // end namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include + +// Config +#if (__CUDACC_VER_MAJOR__ >= 12) +# define CUTE_COPY_ATOM_TMA_SM90_ENABLED +#endif + +#if defined(CUTE_COPY_ATOM_TMA_SM90_ENABLED) +#include +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/atom/copy_traits.hpp b/include/cute/atom/copy_traits.hpp new file mode 100644 index 0000000000..ac746a64e1 --- /dev/null +++ b/include/cute/atom/copy_traits.hpp @@ -0,0 +1,162 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include + +namespace cute +{ + +/** + * concept Copy_Traits + * { + * using ThrID = // Logical thread id (tid) -> tidx + * + * using SrcLayout = // (Logical src thread id (tid), Logical src value id (vid)) -> bit + * using DstLayout = // (Logical dst thread id (tid), Logical dst value id (vid)) -> bit + * using RefLayout = // (Logical ref thread id (tid), Logical ref value id (vid)) -> bit + * }; + * + * The abstract bit ordering of the Copy_Traits (the codomain of SrcLayout, DstLayout, and RefLayout) + * is arbitrary and only used to construct maps + * (ref-tid,ref-vid) -> (src-tid,src-vid) + * (ref-tid,ref-vid) -> (dst-tid,dst-vid) + * in TiledCopy. The Layout_TV in TiledCopy is in accordance with the RefLayout of a Traits, then mapped to + * the Src or Dst (tid,vid) representation on demand. + * + */ + +template +struct Copy_Traits +{ + static_assert(dependent_false, "Copy_Traits not implemented for this CopyOperation."); +}; + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0,_0>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride<_0,_0>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +// Extract a CPY_Op from a CPY_Traits +template +struct CPY_Op {}; + +template +struct CPY_Op> { + using type = CPY_Op_Arg; +}; + +// +// Generic copy_unpack for common argument-based Copy_Traits +// + +template +CUTE_HOST_DEVICE constexpr +void +copy_unpack(AnyCPYTraits const&, + Tensor const& src, + Tensor & dst) +{ + using CopyOp = typename CPY_Op::type; + using RegistersSrc = typename CopyOp::SRegisters; + using RegistersDst = typename CopyOp::DRegisters; + using RegTypeSrc = typename remove_extent::type; + using RegTypeDst = typename remove_extent::type; + constexpr int RegNumSrc = extent::value; + constexpr int RegNumDst = extent::value; + + Tensor rS = recast(src); + Tensor rD = recast(dst); + + CUTE_STATIC_ASSERT_V(size(rS) == Int{}, + "Copy_Traits: src failed to vectorize into registers. Layout is incompatible with this CopyOp."); + CUTE_STATIC_ASSERT_V(size(rD) == Int{}, + "Copy_Traits: dst failed to vectorize into registers. 
Layout is incompatible with this CopyOp."); + + detail::explode(detail::CallCOPY{}, + rS, make_int_sequence{}, + rD, make_int_sequence{}); +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE constexpr +void +copy_unpack(AnyCPYTraits const& traits, + Tensor const& src, + Tensor && dst) +{ + copy_unpack(traits, src, dst); +} + +namespace detail { + +template +constexpr bool is_prefetch = false; + +template +constexpr bool is_prefetch> = is_same_v; + +} // end namespace detail + + +} // end namespace cute diff --git a/include/cute/atom/copy_traits_sm50.hpp b/include/cute/atom/copy_traits_sm50.hpp new file mode 100644 index 0000000000..7a693805e6 --- /dev/null +++ b/include/cute/atom/copy_traits_sm50.hpp @@ -0,0 +1,75 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include + +namespace cute +{ + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_64, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_32, _2>>, + Stride,Stride< _1, _64>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_64, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Shape<_32, _2>>, + Stride,Stride< _1, _256>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +} // end namespace cute diff --git a/include/cute/atom/copy_traits_sm75.hpp b/include/cute/atom/copy_traits_sm75.hpp new file mode 100644 index 0000000000..9ad82c6174 --- /dev/null +++ b/include/cute/atom/copy_traits_sm75.hpp @@ -0,0 +1,143 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
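The warp-wide traits in copy_traits_sm50.hpp above (and the ldmatrix traits added below in copy_traits_sm75.hpp) all set ThrID = Layout<_32>, so they describe warp-cooperative copies; a TiledCopy built from such an atom is most naturally derived from a TiledMMA partitioning. A hedged sketch, with all tensor and variable names invented and only the op/helper names (SM75_U32x4_LDSM_N, make_tiled_copy_A, Copy_Atom) taken from CuTe:

// Sketch only: assumes a TiledMMA `tiled_mma`, a shared-memory tile `sA`,
// and half_t data; the variable names are illustrative, not from this patch.
auto smem_tiled_copy_A = make_tiled_copy_A(Copy_Atom<SM75_U32x4_LDSM_N, half_t>{}, tiled_mma);
auto smem_thr_copy_A   = smem_tiled_copy_A.get_thread_slice(threadIdx.x);

auto thr_mma = tiled_mma.get_thread_slice(threadIdx.x);
auto tCrA    = thr_mma.partition_fragment_A(sA);      // register destination fragments
auto tCsA    = smem_thr_copy_A.partition_S(sA);       // smem source, tiled by the atom
auto tCrA_v  = smem_thr_copy_A.retile_D(tCrA);        // destination viewed in atom terms

copy(smem_tiled_copy_A, tCsA(_,_,_0{}), tCrA_v(_,_,_0{}));  // copy one K-block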
+ * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include + +namespace cute +{ + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, + Stride<_32, _1>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_128, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>, + Stride<_32,Stride< _1,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_16, _2>>, + Stride,Stride< _1,_128>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout,_128>, + Stride, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_16, _2, _2>>, + Stride,Stride< _1,_128,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, + Stride<_128, _1>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout,Shape <_16, _2, _4>>, + Stride,Stride< _1,_128,_1024>>>; + + // Reference map from (thr,val) to bit + using RefLayout = DstLayout; +}; + +} // end namespace cute diff --git a/include/cute/atom/copy_traits_sm80.hpp b/include/cute/atom/copy_traits_sm80.hpp new file mode 100644 index 0000000000..3795f52a89 --- /dev/null +++ b/include/cute/atom/copy_traits_sm80.hpp @@ -0,0 +1,167 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include + +namespace cute +{ + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // Predicate value: true = load, false = zfill + bool pred = true; + + // Construct a zfill variant with a given predicate value + CUTE_HOST_DEVICE constexpr + Copy_Traits> + with(bool pred) const { + return {pred}; + } + + // Overload copy_unpack for zfill variant to pass the predicate into the op + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_gmem::value, "Expected gmem source for cp.async."); + static_assert(is_smem
::value, "Expected smem destination for cp.async."); + + Tensor rS = recast(src); + Tensor rD = recast(dst); + + CUTE_STATIC_ASSERT_V(size(rS) == Int<1>{}, + "In CopyAtom, src layout doesn't vectorize into registers. This src layout is incompatible with this tiled copy."); + CUTE_STATIC_ASSERT_V(size(rD) == Int<1>{}, + "In CopyAtom, dst layout doesn't vectorize into registers. This dst layout is incompatible with this tiled copy."); + + SM80_CP_ASYNC_CACHEALWAYS_ZFILL::copy(rS[0], rD[0], traits.pred); + } +}; + +template +struct Copy_Traits> +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout::value>>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout::value>>>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // Predicate value: true = load, false = zfill + bool pred = true; + + // Construct a zfill variant with a given predicate value + CUTE_HOST_DEVICE constexpr + Copy_Traits> + with(bool pred) const { + return {pred}; + } + + // Overload copy_unpack for zfill variant to pass the predicate into the op + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_gmem::value, "Expected gmem source for cp.async."); + static_assert(is_smem::value, "Expected smem destination for cp.async."); + + Tensor rS = recast(src); + Tensor rD = recast(dst); + + CUTE_STATIC_ASSERT_V(size(rS) == Int<1>{}, + "In CopyAtom, src layout doesn't vectorize into registers. This src layout is incompatible with this tiled copy."); + CUTE_STATIC_ASSERT_V(size(rD) == Int<1>{}, + "In CopyAtom, dst layout doesn't vectorize into registers. This dst layout is incompatible with this tiled copy."); + + SM80_CP_ASYNC_CACHEGLOBAL_ZFILL::copy(rS[0], rD[0], traits.pred); + } +}; + +} // end namespace cute diff --git a/include/cute/atom/copy_traits_sm90.hpp b/include/cute/atom/copy_traits_sm90.hpp new file mode 100644 index 0000000000..f9590848af --- /dev/null +++ b/include/cute/atom/copy_traits_sm90.hpp @@ -0,0 +1,132 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include + +#include + +namespace cute +{ + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = typename Copy_Traits::DstLayout; + // Map from (dst-thr,dst-val) to bit + using DstLayout = typename Copy_Traits::SrcLayout; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = typename Copy_Traits::DstLayout; + // Map from (dst-thr,dst-val) to bit + using DstLayout = typename Copy_Traits::SrcLayout; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = typename Copy_Traits::DstLayout; + // Map from (dst-thr,dst-val) to bit + using DstLayout = typename Copy_Traits::SrcLayout; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = typename Copy_Traits::DstLayout; + // Map from (dst-thr,dst-val) to bit + using DstLayout = typename Copy_Traits::SrcLayout; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = typename Copy_Traits::DstLayout; + // Map from (dst-thr,dst-val) to bit + using DstLayout = typename Copy_Traits::SrcLayout; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +template <> +struct Copy_Traits +{ + // Logical thread id to thread idx (warp) + using ThrID = Layout<_32>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = typename Copy_Traits::DstLayout; + // Map from (dst-thr,dst-val) to bit + using DstLayout = typename Copy_Traits::SrcLayout; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; +}; + +} // end namespace cute diff --git a/include/cute/atom/copy_traits_sm90_im2col.hpp b/include/cute/atom/copy_traits_sm90_im2col.hpp new file mode 100644 index 0000000000..54f76073b1 --- /dev/null +++ b/include/cute/atom/copy_traits_sm90_im2col.hpp @@ -0,0 +1,940 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
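The SM90 stmatrix traits above mirror the ldmatrix traits with source and destination swapped, so an epilogue can reuse the MMA partitioning to store accumulators to shared memory warp-cooperatively. A hedged sketch, with variable names invented and only the op/helper names (SM90_U32x4_STSM_N, make_tiled_copy_C) taken from CuTe:

// Sketch only: assumes a TiledMMA `tiled_mma`, accumulators `tCrC` already
// converted to half_t, and a shared-memory tile `sC`; names are illustrative.
auto smem_tiled_copy_C = make_tiled_copy_C(Copy_Atom<SM90_U32x4_STSM_N, half_t>{}, tiled_mma);
auto smem_thr_copy_C   = smem_tiled_copy_C.get_thread_slice(threadIdx.x);

auto tCrC_v = smem_thr_copy_C.retile_S(tCrC);   // registers viewed in copy-atom terms
auto tCsC   = smem_thr_copy_C.partition_D(sC);  // smem destination partition

copy(smem_tiled_copy_C, tCrC_v, tCsC);          // warp-cooperative stmatrix store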
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/*! \file + \brief im2col make_tma_copy +*/ + +#include "cute/arch/copy_sm90.hpp" +#include "cute/arch/copy_sm90_desc.hpp" +#include "cute/tensor.hpp" + +#include "cute/algorithm/prefetch.hpp" +#include "cutlass/fast_math.h" +#include "cutlass/cuda_host_adapter.hpp" + +namespace cute +{ + +// Utility for unpacking TMA_LOAD_IM2COL arguments into a CopyOp +template +struct TMA_LOAD_IM2COL_Unpack +{ + /// Copy from src to dst. + /// + /// @param traits Copy traits created with a TMA descriptor that + /// correctly matches the input tensor and other convolution + /// parameters. + /// + /// @param src Tile of the im2col-transformed coordinate tensor + /// (result of get_tma_tensor), representing the global-memory + /// tensor from which to load. + /// + /// @param dst Shared memory tile, into which to load. 
+ template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, // tile of the transformed global activation (A) tensor + Tensor & dst) // shared memory tile + { + auto src_coord_offset = src(Int<0>{}); + auto src_coord_cwhdn_offset_srt = flatten(src_coord_offset); + // Interpret the TMA IM2COL coordinate as (c, ([w,h,d]), n, ([s,r,t])) + CUTE_STATIC_ASSERT_V(rank(src_coord_offset) == _4{}); + CUTE_STATIC_ASSERT_V(rank<1>(src_coord_offset) == rank<3>(src_coord_offset)); + + if constexpr (detail::is_prefetch) { + return detail::explode_tuple(detail::CallCOPY{}, + traits.opargs_, tuple_seq{}, + src_coord_cwhdn_offset_srt, tuple_seq{}); + } else { + static_assert(is_smem::value, "SM90_TMA_LOAD_IM2COL requires the destination be shared memory."); + void* dst_ptr = cute::raw_pointer_cast(dst.data()); + return detail::explode_tuple(detail::CallCOPY{}, + traits.opargs_, tuple_seq{}, + make_tuple(dst_ptr), seq<0>{}, + src_coord_cwhdn_offset_srt, tuple_seq{}); + } + } +}; + +// Copy_Traits for SM90 im2col TMA load comes in two layers. +// +// 1. Copy_Traits +// 2. Copy_Traits +// +// Copy_Traits +// is the "outer" layer. It has a TMA descriptor, +// but no barrier ("tma_mbar"), so it's "nonexecutable." +// One calls its "with" member function with a barrier, +// to get an executable "inner"-layer +// Copy_Traits object. +// That object's "copy_unpack" member function +// actually invokes im2col TMA load. + +struct SM90_TMA_LOAD_IM2COL_OP : SM90_TMA_LOAD_IM2COL {}; + +/// @brief Non-executable specialization of Copy_Traits for SM90 +/// im2col TMA load, with TMA descriptor but no barrier. +/// +/// Use `.with(memory_barrier)` to construct an executable version. +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + Im2ColTmaDescriptor tma_desc_; + TMATensor tma_tensor_; + + CUTE_HOST_DEVICE constexpr + Im2ColTmaDescriptor const* + get_tma_descriptor() const + { + return &tma_desc_; + } + + template + CUTE_HOST_DEVICE constexpr + TMATensor const + get_tma_tensor(GShape const&) const + { + return tma_tensor_; + } + + /// @brief Get an executable specialization. + /// + /// Copy_Traits specializations with SM90_TMA_LOAD_IM2COL are not + /// directly executable. Instead, call this "with" member function + /// to get an executable specialization. "Executable" means that + /// @c copy_unpack works. + /// + /// @param tma_mbar Memory barrier for synchronization + /// + /// @param multicast_mask Multicast mask (unused; only exists + /// for interface compatibility with the actual multicast Copy_Traits) + /// + /// @return Executable specialization of @c Copy_Traits + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& tma_mbar, [[maybe_unused]] uint16_t const& multicast_mask = 0) const + { + return {{}, {&tma_desc_, &tma_mbar}}; + } + + // Copy_Traits specializations with SM90_TMA_LOAD_IM2COL + // are not directly executable. Instead, call .with + // to get an executable specialization. + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +/// @brief Executable specialization of Copy_Traits for SM90 im2col +/// TMA load, with TMA descriptor and barrier. 
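Because only the executable layer defines a usable copy_unpack, kernels never name that specialization directly; they call .with(mbar) on the atom whenever a barrier is at hand. A hedged device-side sketch (the copy object, partitions, barrier, and byte count below are all invented names; copy, with, set_barrier_transaction_bytes, and wait_barrier are CuTe APIs):

// Sketch only: `tma_load_a` is the object returned by make_im2col_tma_copy
// (defined later in this file), `tAgA` a partition of its TMA coordinate
// tensor, `tAsA` the matching smem partition, `smem_mbar` an initialized
// shared-memory mbarrier, and kTmaTransactionBytes the expected byte count.
if (threadIdx.x == 0) {
  cute::set_barrier_transaction_bytes(smem_mbar, kTmaTransactionBytes);
  copy(tma_load_a.with(smem_mbar), tAgA, tAsA);   // builds the executable traits
}
cute::wait_barrier(smem_mbar, /*phase=*/0);       // all threads wait for completion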
+template +struct Copy_Traits + : TMA_LOAD_IM2COL_Unpack +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD_IM2COL arguments + tuple< + Im2ColTmaDescriptor const*, + uint64_t* // smem mbarrier + > const opargs_; +}; + +template +struct Copy_Traits + : TMA_LOAD_IM2COL_Unpack +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD_IM2COL::PREFETCH arguments + tuple const opargs_; + + CUTE_HOST_DEVICE + Copy_Traits(Copy_Traits const& traits) + : opargs_({&traits.tma_desc_}) {} +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_LOAD_MULTICAST ///////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_IM2COL_MULTICAST_OP : SM90_TMA_LOAD_IM2COL_MULTICAST {}; + +/// @brief Non-executable specialization of Copy_Traits for SM90 +/// im2col TMA load, with TMA descriptor but no barrier or multicast +/// mask. +/// +/// Use `.with(memory_barrier)` to construct an executable version. +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + Im2ColTmaDescriptor tma_desc_; + TMATensor tma_tensor_; + + CUTE_HOST_DEVICE constexpr + Im2ColTmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + template + CUTE_HOST_DEVICE constexpr + TMATensor const + get_tma_tensor(GShape const&) const + { + return tma_tensor_; + } + + /// @brief Get an executable specialization. + /// + /// Copy_Traits specializations with SM90_TMA_LOAD_IM2COL_MULTICAST + /// are not directly executable. Instead, call this "with" member + /// function to get an executable specialization. "Executable" + /// means that @c copy_unpack works. + /// + /// @param tma_mbar Memory barrier for synchronization + /// + /// @param multicast_mask Multicast mask (defaults to a single CTA) + /// + /// @return Executable specialization of @c Copy_Traits + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& tma_mbar, uint16_t const& multicast_mask) const { + return {{}, {&tma_desc_, &tma_mbar, multicast_mask}}; + } + + // Copy_Traits specializations with SM90_TMA_LOAD_IM2COL_MULTICAST + // are not directly executable. Instead, call .with to get an + // executable specialization. + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +/// @brief Executable specialization of Copy_Traits for SM90 multicast +/// im2col TMA load, with TMA descriptor, barrier, and multicast mask. +template +struct Copy_Traits + : TMA_LOAD_IM2COL_Unpack +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit. 
+ using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD_IM2COL_MULTICAST arguments + tuple< + Im2ColTmaDescriptor const*, + uint64_t*, // smem mbarrier + uint16_t // multicast mask + > const opargs_; +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_STORE IM2COL//////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +// The executable SM90_TMA_STORE_IM2COL with tma_desc +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_STORE_IM2COL arguments + Im2ColTmaDescriptor tma_desc_; + TMATensor tma_tensor_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + Im2ColTmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + template + CUTE_HOST_DEVICE constexpr + TMATensor const + get_tma_tensor(GShape const&) const + { + return tma_tensor_; + } + + // This is the copy_unpack dispatch for this Copy_Traits + // Src needs to be a smem tensor + // Dst needs to be a gmem tensor with TmaCoordIterator .data() + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_smem::value, "Expected smem src for SM90_TMA_STORE_IM2COL"); + + void const* const desc_ptr = &(traits.tma_desc_); + void const* const src_ptr = cute::raw_pointer_cast(src.data()); + auto dst_coord = flatten(take<0,3>(dst(Int<0>{}))); + + return detail::explode_tuple(detail::CallCOPY{}, + make_tuple(desc_ptr, src_ptr), seq<0,1>{}, + dst_coord, tuple_seq{}); + } +}; + +namespace detail { + +/// @brief Creates a TMA descriptor for im2col TMA load. +/// +/// @param tensor_cwhdn Global activation tensor (A matrix of Fprop). +/// This is the original (not im2col-transformed) tensor in global +/// memory. +/// +/// @param slayout Rank 2 (M,K) shared memory layout of the activation +/// tensor. Here, K is "GEMM K," not the filter tensor's mode of +/// the same name. +////// +/// @param traversal_stride Traversal strides convolution parameter +////// +/// Each of padding_shape, traversal_stride, and dilation_shape is a +/// tuple whose size is the number of spatial modes (e.g., 3 for a 5-D +/// convolution). 
+/// +/// @return TMA descriptor for im2col TMA load +template +CUTE_HOST +auto +make_im2col_tma_copy_desc( + Tensor const& tensor_cwhdn, // (C,W,H,D,N) + uint32_t range_c, // TILE_C + uint32_t range_whdn, // TILE_WHDN + SmemSwizzle const& smem_swizzle, // Swizzle + TMALayout const& tma_layout_vt, // TMA layout + LowerCornerStride const& lower_corner_whd, // WHD offset of the "base pointer" + UpperCornerStride const& upper_corner_whd, // WHD upper corner + LowerPaddingStride const& lower_padding_whd, // WHD lower padding + UpperPaddingStride const& upper_padding_whd, // WHD upper padding + TraversalStride const& stride_whd, // WHD traversal stride + LowerSRTStride const& lower_srt, // SRT offset of the "base pointer" + DilationStride const& stride_srt, // SRT stride - dilation + TMA::DescriptorAuxParams const& aux_params = {}) +{ + static_assert(is_gmem::value, "Tensor must point to GPU global memory."); + using value_type = typename EngineA::value_type; + + constexpr uint32_t num_total_modes = LayoutA::rank; + constexpr int num_spatial_modes = num_total_modes - 2; + + // Gmem starting address + void* gmem_address = (void*) raw_pointer_cast(tensor_cwhdn.data()); + + // Gmem extents are just the tensor shape + cute::array gmem_prob_shape = {1,1,1,1,1}; + for_each(make_seq{}, [&](auto i) { + gmem_prob_shape[i] = static_cast(shape(tensor_cwhdn)); + }); + + // Gmem strides are byte strides of the activation tensor in CWHDN order + cute::array gmem_prob_stride = {0,0,0,0,0}; + for_each(make_seq{}, [&](auto i) { + gmem_prob_stride[i] = sizeof(value_type) * stride(tensor_cwhdn); + }); + + // Traversal strides are a function of the dilation shape + // corresponding to spatial (WHD) modes. + cute::array tma_traversal_strides = {1,1,1,1,1}; + for_each(make_seq{}, [&](auto i) { + tma_traversal_strides[i+1] = static_cast(get(stride_whd)); + }); + + cute::array tma_lower_corner{}; + for_each(make_seq{}, [&](auto i) { + tma_lower_corner[i] = static_cast(get(lower_corner_whd)); + }); + + cute::array tma_upper_corner{}; + for_each(make_seq{}, [&](auto i) { + tma_upper_corner[i] = static_cast(get(upper_corner_whd)); + }); + + Im2ColTmaDescriptor tma_desc; + +#if (__CUDACC_VER_MAJOR__ >= 12) + + CUtensorMapDataType tma_format = TMA::to_CUtensorMapDataType(); + CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; + CUtensorMapL2promotion tma_l2Promotion = to_CUtensorMapL2promotion(aux_params.l2promo_); + CUtensorMapFloatOOBfill tma_oob_fill = to_CUtensorMapFloatOOBfill(aux_params.oobfill_); + TMA::SmemSwizzleBits swizzle_bits = detail::get_tma_swizzle_bits(smem_swizzle); + TMA::SmemSwizzleBase swizzle_base = detail::get_tma_swizzle_base(smem_swizzle); + CUtensorMapSwizzle tma_swizzle = TMA::to_CUtensorMapSwizzle(swizzle_bits, swizzle_base); + + CUresult encode_result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeIm2col)( + &tma_desc, + tma_format, + num_total_modes, + gmem_address, + gmem_prob_shape.data(), + gmem_prob_stride.data() + 1, // gmem_prob_stride[0] implicitly sizeof(value_type) + tma_lower_corner.data(), + tma_upper_corner.data(), + range_c, + range_whdn, + tma_traversal_strides.data(), + tma_interleave, + tma_swizzle, + tma_l2Promotion, + tma_oob_fill); + + // The extra asserts help indicate the error's cause. 
+ assert(encode_result != CUDA_ERROR_DEINITIALIZED); + assert(encode_result != CUDA_ERROR_NOT_INITIALIZED); + assert(encode_result != CUDA_ERROR_INVALID_CONTEXT); + assert(encode_result != CUDA_ERROR_INVALID_VALUE); + assert(encode_result == CUDA_SUCCESS); + +#endif // (__CUDACC_VER_MAJOR__ >= 12) + // + // Calculate gemm shapes and linearized shapes based on tma layout tiling. + // + + // Compute [w, h, d, n] + // q/p/z = (w/h/d + (upper_corner_whd - lower_corner_whd - 1)) / stride_whd + 1 + auto gemm_mn_ = cute::transform(cute::make_seq{}, [&](auto i) { + return (shape(tensor_cwhdn) + get(upper_corner_whd) - get(lower_corner_whd) - Int<1>{}) / get(stride_whd) + Int<1>{}; + }); + auto gemm_mn = append(gemm_mn_, shape(tensor_cwhdn)); + + // Compute [c, s, r, t] + // fprop/wgrad, s/r/t = 1 + (upper_padding_whd - upper_corner_whd) / stride_srt + // wgrad, s/r/t = 1 + (lower_padding_whd - lower_corner_whd) / stride_srt + auto gemm_k_ = cute::transform(cute::make_seq{}, [&](auto i) { + auto padding_size = conditional_return(get(stride_srt) > Int<0>{}, + get(upper_padding_whd) - get(upper_corner_whd), + get(lower_corner_whd) - get(lower_padding_whd)); + return Int<1>{} + padding_size / get(stride_srt); + }); + auto gemm_k = prepend(gemm_k_, shape<0>(tensor_cwhdn)); + + // For fprop/dgrad kernel, gemm_shapes is ((q, p, z, n), (c, s, r, t)) + // For wgrad kernel, gemm_shapes is ((c, s, r, t), (q, p, z, n)) + auto gemm_shapes_common = make_shape( + transform_leaf(gemm_mn, [](auto s) { + return conditional_return(cute::is_static{}, s, cutlass::FastDivmod(s)); + }), + gemm_k); + auto gemm_shapes = make_shape( + basis_get(stride<0,1>(tma_layout_vt), gemm_shapes_common), + basis_get(stride<0,0>(tma_layout_vt), gemm_shapes_common)); + + // For fprop/dgrad kernel, linearized shapes is (whdn, (c, s, r, t)) + // For wgrad kernel linearized shapes is ((c, s, r, t), whdn) + auto linear_shapes_common = make_shape(size(gemm_mn), gemm_k); + auto linear_shapes = make_shape( + basis_get(stride<0,1>(tma_layout_vt), linear_shapes_common), + basis_get(stride<0,0>(tma_layout_vt), linear_shapes_common)); + + // + // Calculate gmem basis stride based on tma layout tiling. 
+ // + + auto tma_basis_scale = make_shape(Int<1>{}, stride_whd, Int<1>{}, stride_srt); + auto tma_basis = elem_scale(tma_basis_scale, make_basis_like(tma_basis_scale)); + + auto gbasis_strides_common = make_stride( + append(get<1>(tma_basis), get<2>(tma_basis)), + prepend(get<3>(tma_basis), get<0>(tma_basis))); // ((w,h,d,n),(c,s,r,t)) + auto gbasis_strides = make_stride( + basis_get(stride<0,1>(tma_layout_vt), gbasis_strides_common), + basis_get(stride<0,0>(tma_layout_vt), gbasis_strides_common)); + + // + // Create tma tensor + // + + auto lower_corner = make_arithmetic_tuple(Int<0>{}, lower_corner_whd, Int<0>{}, lower_srt); + + auto tensor_multimode = make_tensor(ArithmeticTupleIterator(lower_corner), gemm_shapes, gbasis_strides); + auto tensor_linear = make_identity_tensor(linear_shapes); + auto tma_tensor = make_tensor(tensor_multimode.data(), composition( + tensor_multimode.layout(), + tensor_linear(Int<0>{}), + tensor_linear.layout())); + + return cute::make_tuple(tma_desc, tma_tensor); +} + +template +CUTE_HOST_RTC +auto +make_tma_atom_im2col(CopyOp, + Tensor const& gtensor, // Full GMEM Tensor: ((w, h, d, n), c) + SLayout const& slayout, // CTA Tile of SMEM, potentially swizzled + int32_t const& num_multicast, // The number of CTAs involved in multicasting + Layout const& cta_v_map, // V: CTA val idx -> gmem mode + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, // traversal stride + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt, // dilation + TMA::DescriptorAuxParams const& aux_params = {}) +{ + // + // TMA parameter checking + // + + CUTE_STATIC_ASSERT_V(product_each(shape(slayout)) == product_each(shape(cta_v_map)), + "TMA requires CTA_Tile and SLayout top-level shape equivalence."); + + // + // TMA slayout manipulation + // + + // Invert the smem to get the largest contiguous vector in the smem layout + auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout)); + // trunc_smem_idx -> trunc_smem_coord + + // Map from smem idx to a gmem mode + auto sidx_to_gmode = coalesce(composition(cta_v_map, inv_smem_layout)); + +#if 0 + print("g_layout : "); print(gtensor.layout()); print("\n"); + print("s_layout : "); print(slayout); print("\n"); + print("cta_t_map : "); print(cta_t_map); print("\n"); + print("cta_v_map : "); print(cta_v_map); print("\n"); + print("inv_smem : "); print(inv_smem_layout); print("\n"); + print("sidx_to_gmode : "); print(sidx_to_gmode); print("\n"); +#endif + + // + // TMA gtensor manipulation + // + + // Generate a TupleBasis for the gtensor + auto glayout_basis = make_identity_layout(product_each(shape(gtensor))); + + // Tile the modes of gtensor with the truncated cta_v_map o inv_smem_layout_trunc + auto tma_layout_full = flatten(composition(glayout_basis, sidx_to_gmode)); + + // Truncate any incompatibilities -- no starting in the middle of gmodes + auto smem_rank = find_if(stride(tma_layout_full), [](auto e) { + [[maybe_unused]] auto v = basis_value(e); + return not is_constant<1,decltype(v)>{}; + }); + static_assert(smem_rank >= 2, "IM2COL expects at least 2 modes of the smem to vectorize with gmem."); + // IM2COL uses a maximum of 2 modes + constexpr int smem_tma_rank = cute::min(int(smem_rank), 2); + + // Keep only the static-1 basis modes into gmem + auto tma_layout_trunc = take<0,smem_tma_rank>(tma_layout_full); + + // Split according to the 
portion each multicast CTA will be responsible for + auto tma_layout_vt = logical_divide(tma_layout_trunc, shape_div(size(tma_layout_trunc), num_multicast)); + +#if 0 + print("glayout_basis : "); print(glayout_basis); print("\n"); + print("tma_layout_full : "); print(tma_layout_full); print("\n"); + + print("tma_layout_trunc: "); print(tma_layout_trunc); print("\n"); + print("tma_layout_vt : "); print(tma_layout_vt); print("\n"); +#endif + + auto range_c = size<0,0>(tma_layout_vt); + auto range_whdn = size<0,1>(tma_layout_vt); + Tensor gtensor_cwhdn = make_tensor(gtensor.data(), + flatten(make_layout(make_layout(basis_get(stride<0,0>(tma_layout_vt), gtensor.shape()), + basis_get(stride<0,0>(tma_layout_vt), gtensor.stride())), + make_layout(basis_get(stride<0,1>(tma_layout_vt), gtensor.shape()), + basis_get(stride<0,1>(tma_layout_vt), gtensor.stride()))))); + auto [tma_desc, tma_tensor] = make_im2col_tma_copy_desc( + gtensor_cwhdn, + range_c, + range_whdn, + detail::get_swizzle_portion(slayout), + tma_layout_vt, + lower_corner_whd, + upper_corner_whd, + lower_padding_whd, + upper_padding_whd, + stride_whd, + lower_srt, + stride_srt, + aux_params); + + // + // Construct the Copy_Traits + // + + using T = typename GEngine::value_type; + constexpr int num_bits_per_tma = decltype(size(tma_layout_trunc))::value * sizeof(T) * 8; + + using Traits = Copy_Traits, decltype(tma_tensor)>; + using Atom = Copy_Atom; + +#if 0 + print("num_bits : "); print(num_bits_per_tma); print("\n"); +#endif + + Traits tma_traits{tma_desc, tma_tensor}; + + // Return the Copy_Atom + return Atom{tma_traits}; +} + +/// Make a TiledCopy for im2col TMA load. +/// +/// @param copy_op The copy implementation: either +/// SM90_TMA_LOAD_IM2COL or SM90_TMA_LOAD_IM2COL_MULTICAST. +/// +/// @param tensor_cwhdn The global tensor to use for im2col TMA loads. +/// For Fprop convolutions, this is the activation tensor. This is +/// the "original tensor that points to global memory, not the +/// coordinate (im2col-transformed) tensor. +/// +/// @param slayout Layout of shared memory tile. +/// +/// @param stride_whd The traversal strides convolution +/// parameter. +/// +/// @return TiledCopy specialization for im2col TMA loads. 
+template +CUTE_HOST_RTC +auto +make_tma_copy_im2col(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + Layout const& cta_t_map, // CTA tid -> logical TMA tid + Layout const& cta_v_map, // CTA vid -> gmem coord + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, // traversal stride + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt, // dilation + TMA::DescriptorAuxParams const& aux_params = {}) +{ + // + // TMA parameter checking + // + + CUTE_STATIC_ASSERT_V(size(slayout) % cosize(cta_t_map) == Int<0>{}, + "Number of active CTAs in TMA must divide domain size of slayout."); + + Copy_Atom atom = make_tma_atom_im2col(copy_op, gtensor, slayout, cosize(cta_t_map), cta_v_map, + lower_corner_whd, upper_corner_whd, lower_padding_whd, + upper_padding_whd, stride_whd, lower_srt, stride_srt, aux_params); + + // + // Construct the TiledCopy + // + + auto cta_tiler = product_each(shape(cta_v_map)); + + auto num_elems_per_tma = size<1>(typename decltype(atom)::RefLayout{}) / static_value>(); + + // smem idx -> smem coord + auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout)); + // CTA V -> smem_coord + auto layout_v = composition(inv_smem_layout, num_elems_per_tma); + // Scale that up to cover all of the smem_coords + auto layout_V = tile_to_shape(make_layout(layout_v), size(cta_v_map)); + // CTA T -> smem idx + auto layout_t = make_layout(cosize(cta_t_map), shape_div(num_elems_per_tma, cosize(cta_t_map))); + // CTA TID -> smem coord + auto layout_T = composition(inv_smem_layout, composition(layout_t, cta_t_map)); + // Combine with the T mapping + [[maybe_unused]] auto layout_TV = make_layout(layout_T, layout_V); + +#if 0 + print("cta_tiler : "); print(cta_tiler); print("\n"); + print("layout_v : "); print(layout_v); print("\n"); + print("layout_V : "); print(layout_V); print("\n"); + print("layout_t : "); print(layout_t); print("\n"); + print("layout_T : "); print(layout_T); print("\n"); + print("layout_TV : "); print(layout_TV); print("\n"); +#endif + + return TiledCopy{atom}; +} + +/// Make a TiledCopy for im2col TMA with no offsets. +/// E.g. im2col TMA load for C and im2col TMA store for D. 
+template +CUTE_HOST_RTC +auto +make_tma_copy_im2col(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + Layout const& cta_t_map, // CTA tid -> logical TMA tid + Layout const& cta_v_map) // CTA vid -> gmem coord +{ + constexpr int num_spatial_modes = rank<0>(GLayout{}) - 1; + return make_tma_copy_im2col(copy_op, gtensor, slayout, cta_t_map, cta_v_map, + append(Stride<_0>{}, Int<0>{}), // lower_corner_whd + append(Stride<_0>{}, Int<0>{}), // upper_corner_whd + append(Stride<_0>{}, Int<0>{}), // lower_padding_whd + append(Stride<_0>{}, Int<0>{}), // upper_padding_whd + append(Stride<_1>{}, Int<1>{}), // stride_whd + append(Stride<_0>{}, Int<0>{}), // lower_srt + append(Stride<_1>{}, Int<1>{})); // stride_srt +} + +} // namespace detail + + + +template +CUTE_HOST_RTC +auto +make_im2col_tma_copy(CopyOp const& copy_op, + Tensor const& tensor_cwhdn, + SLayout const& slayout, + CTATiler const& cta_tiler, + MulticastSize const& multicast_size, + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt) +{ + auto cta_v_tile = make_identity_layout(product_each(shape(tensor_cwhdn))).compose(cta_tiler); + auto cta_t_tile = make_layout(multicast_size); + + return detail::make_tma_copy_im2col(copy_op, tensor_cwhdn, + slayout, cta_t_tile, cta_v_tile, + lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd, stride_whd, lower_srt, stride_srt); +} + +// Explicit default for multicast_size +template +CUTE_HOST_RTC +auto +make_im2col_tma_copy(CopyOp const& copy_op, + Tensor const& tensor_cwhdn, + SLayout const& slayout, + CTATiler const& cta_tiler, + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt) +{ + return make_im2col_tma_copy(copy_op, tensor_cwhdn, slayout, cta_tiler, Int<1>{}, + lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd, stride_whd, lower_srt, stride_srt); +} + +// Explicit default for cta_tiler and multicast_size +template +CUTE_HOST_RTC +auto +make_im2col_tma_copy(CopyOp const& copy_op, + Tensor const& tensor_cwhdn, + SLayout const& slayout, + LowerCornerStride const& lower_corner_whd, + UpperCornerStride const& upper_corner_whd, + LowerPaddingStride const& lower_padding_whd, + UpperPaddingStride const& upper_padding_whd, + TraversalStride const& stride_whd, + LowerSRTStride const& lower_srt, + DilationStride const& stride_srt) +{ + return make_im2col_tma_copy(copy_op, tensor_cwhdn, slayout, product_each(shape(slayout)), Int<1>{}, + lower_corner_whd, upper_corner_whd, lower_padding_whd, upper_padding_whd, stride_whd, lower_srt, stride_srt); +} + +// No offsets copy. 
+template +CUTE_HOST_RTC +auto +make_im2col_tma_copy(CopyOp const& copy_op, + Tensor const& tensor_cwhdn, + SLayout const& slayout, + CTATiler const& cta_tiler, + MulticastSize const& multicast_size) +{ + auto cta_v_tile = make_identity_layout(product_each(shape(tensor_cwhdn))).compose(cta_tiler); + auto cta_t_tile = make_layout(multicast_size); + + return detail::make_tma_copy_im2col(copy_op, tensor_cwhdn, slayout, cta_t_tile, cta_v_tile); +} + +// Explicit default for multicast_size +template +CUTE_HOST_RTC +auto +make_im2col_tma_copy(CopyOp const& copy_op, + Tensor const& tensor_cwhdn, + SLayout const& slayout, + CTATiler const& cta_tiler) +{ + return make_im2col_tma_copy(copy_op, tensor_cwhdn, slayout, cta_tiler, Int<1>{}); +} + +// Explicit default for cta_tiler and multicast_size +template +CUTE_HOST_RTC +auto +make_im2col_tma_copy(CopyOp const& copy_op, + Tensor const& tensor_cwhdn, + SLayout const& slayout) +{ + return make_im2col_tma_copy(copy_op, tensor_cwhdn, slayout, product_each(shape(slayout)), Int<1>{}); +} + +} // namespace cute diff --git a/include/cute/atom/copy_traits_sm90_tma.hpp b/include/cute/atom/copy_traits_sm90_tma.hpp new file mode 100644 index 0000000000..4ad7f80851 --- /dev/null +++ b/include/cute/atom/copy_traits_sm90_tma.hpp @@ -0,0 +1,1560 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
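Taken together, the make_im2col_tma_copy overloads above construct both the Im2ColTmaDescriptor and the TMA coordinate tensor from a single call. A hedged host-side sketch of the fully parameterized overload, using the same zero-padding / unit-stride / unit-dilation values as the no-offsets overload; the tensor, layout, and tiler names are invented for illustration:

// Sketch only: gA is a (C,W,H,N) half_t activation tensor in gmem, sA_layout
// the (possibly swizzled) CTA smem layout, and cta_tiler the CTA tile shape.
auto tma_load_a = make_im2col_tma_copy(
    SM90_TMA_LOAD_IM2COL{},
    gA,                              // full gmem activation tensor
    sA_layout,                       // CTA tile of smem
    cta_tiler,                       // CTA tile over the gmem modes
    Int<1>{},                        // multicast size (no multicast)
    make_tuple(Int<0>{}, Int<0>{}),  // lower_corner_whd
    make_tuple(Int<0>{}, Int<0>{}),  // upper_corner_whd
    make_tuple(Int<0>{}, Int<0>{}),  // lower_padding_whd
    make_tuple(Int<0>{}, Int<0>{}),  // upper_padding_whd
    make_tuple(Int<1>{}, Int<1>{}),  // stride_whd (traversal stride)
    make_tuple(Int<0>{}, Int<0>{}),  // lower_srt
    make_tuple(Int<1>{}, Int<1>{})); // stride_srt (dilation)
// Partitioning then uses tma_load_a.get_tma_tensor(...) to obtain the
// im2col coordinate tensor that copy_unpack expects as its source.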
+ * + **************************************************************************************************/ +#pragma once + +#if !defined(__CUDACC_RTC__) +#include +#endif + +#include +#include +#include + +#include + +#include + +#include + +namespace cute +{ + +template +struct AuxTmaParams { + using GmemStrides = GmemTmaBasisStrides_; // Strides for Gmem mode -> Tma coord mode, may be dynamic + GmemStrides g_stride_; + using TmaGmemBasis = TmaGmemBasis_; // Layout for Tma box shape -> Gmem mode(s), always static + static_assert(is_static::value); + using TmaSwizzle = TmaSwizzle_; // Tma swizzle, always Swizzle + static_assert(is_static::value); +}; + +// Utility for unpacking TMA_LOAD arguments into a CopyOp +template +struct TMA_LOAD_Unpack +{ + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_smem::value, "SM90_TMA_LOAD requires the destination be shared memory."); + + auto src_coord = src.data().coord_; + void* dst_ptr = cute::raw_pointer_cast(dst.data()); +#if 0 + auto [c0,c1,c2,c3,c4] = append<5>(src_coord, 0); + printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", + threadIdx.x, threadIdx.y, threadIdx.z, + blockIdx.x, blockIdx.y, blockIdx.z, + int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), dst_ptr); +#endif + return detail::explode_tuple(detail::CallCOPY{}, + traits.opargs_, tuple_seq{}, + make_tuple(dst_ptr), seq<0>{}, + src_coord, tuple_seq{}); + } +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_LOAD /////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_OP : SM90_TMA_LOAD {}; + +// The non-executable SM90_TMA_LOAD with tma_desc and no tma_mbar +// Use .with(tma_mbar) to construct an executable version +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD arguments + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Construct an executable SM90_TMA_LOAD with tma_mbar + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + uint64_t& tma_mbar, + [[maybe_unused]] uint16_t const& multicast_mask = 0, + TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {&tma_desc_, &tma_mbar, static_cast(cache_hint)}; + } + + // Construct an executable SM90_TMA_LOAD with tma_mbar (temp. 
overloaded for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + TmaDescriptor const* new_tma_desc, + uint64_t& tma_mbar, + [[maybe_unused]] uint16_t const& multicast_mask = 0, + TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { + // We accept multicast_mask here to keep the API for both atoms consistent + return {new_tma_desc, &tma_mbar, static_cast(cache_hint)}; + } + + // Generate the TMA coord tensor + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + // Don't try to execute a copy with SM90_TMA_LOAD before calling .with() + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +// The executable SM90_TMA_LOAD with tma_desc and tma_mbar +template +struct Copy_Traits + : TMA_LOAD_Unpack +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD arguments + tuple< + TmaDescriptor const*, + uint64_t*, // smem mbarrier + uint64_t // cache hint + > const opargs_; + + CUTE_HOST_DEVICE + Copy_Traits(TmaDescriptor const* desc, uint64_t* mbar, uint64_t cache) + : opargs_(desc, mbar, cache) {} +}; + +// The prefetch for SM90_TMA_LOAD with tma_desc +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD::PREFETCH arguments + tuple const opargs_; + + // Construct with any other Traits' TMA Desc + template + CUTE_HOST_DEVICE + Copy_Traits(Copy_Traits const& traits) + : opargs_({&traits.tma_desc_}) {} + + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + auto src_coord = src.data().coord_; + return detail::explode_tuple(detail::CallCOPY{}, + traits.opargs_, tuple_seq{}, + src_coord, tuple_seq{}); + } +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_LOAD_MULTICAST ///////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_LOAD_MULTICAST_OP : SM90_TMA_LOAD_MULTICAST {}; + +// The non-executable SM90_TMA_LOAD_MULTICAST with tma_desc and no tma_mbar +// Use .with(tma_mbar, multicast_mask) to construct an executable version +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD_MULTICAST arguments + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Construct an executable SM90_TMA_LOAD_MULTICAST with tma_mbar + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + uint64_t& tma_load_mbar, + uint16_t const& multicast_mask, + 
TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { + return {&tma_desc_, &tma_load_mbar, multicast_mask, static_cast(cache_hint)}; + } + + // Construct an executable SM90_TMA_LOAD_MULTICAST_OP with tma_mbar (temp. overloaded for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with( + TmaDescriptor const* new_tma_desc, + uint64_t& tma_load_mbar, + uint16_t const& multicast_mask, + TMA::CacheHintSm90 const& cache_hint = TMA::CacheHintSm90::EVICT_NORMAL) const { + return {new_tma_desc, &tma_load_mbar, multicast_mask, static_cast(cache_hint)}; + } + + // Generate the TMA coord tensor + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + // Don't try to execute a copy with SM90_TMA_LOAD_MULTICAST before calling .with() + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) = delete; +}; + +// The executable SM90_TMA_LOAD_MULTICAST with tma_desc and tma_mbar and multicast_mask +template +struct Copy_Traits + : TMA_LOAD_Unpack +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_LOAD_MULTICAST arguments + tuple< + TmaDescriptor const*, + uint64_t*, // smem mbarrier + uint16_t, // multicast mask + uint64_t // cache hint + > const opargs_; + + CUTE_HOST_DEVICE + Copy_Traits(TmaDescriptor const* desc, uint64_t* mbar, uint16_t mask, uint64_t hint) + : opargs_(desc, mbar, mask, hint) {} +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_STORE ////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +struct SM90_TMA_STORE_PTR : SM90_TMA_STORE {}; + +// The executable SM90_TMA_STORE with tma_desc +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_STORE arguments + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Generate the TMA coord tensor + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + // Construct new TMA_STORE with (unsafe) swapped out TMA descriptor ptr (for grouped gemm/ptr array gemm) + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(TmaDescriptor const* new_tma_desc) const { + return {new_tma_desc}; + } + + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_smem::value, "Expected smem src for SM90_TMA_STORE"); + //static_assert(is_gmem::value, "Expected gmem dst for SM90_TMA_STORE"); // TMA spoofed src tensor + + void const* const desc_ptr = &(traits.tma_desc_); + void const* const src_ptr = 
cute::raw_pointer_cast(src.data()); + auto dst_coord = dst.data().coord_; +#if 0 + auto [c0,c1,c2,c3,c4] = append<5>(dst_coord, 0); + printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", + threadIdx.x, threadIdx.y, threadIdx.z, + blockIdx.x, blockIdx.y, blockIdx.z, + int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), src_ptr); +#endif + return detail::explode_tuple(detail::CallCOPY{}, + make_tuple(desc_ptr, src_ptr), seq<0,1>{}, + dst_coord, tuple_seq{}); + } +}; + +// Same as SM90_TMA_STORE, but with an unsafe TMA Desc PTR instead +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_STORE arguments + TmaDescriptor const* tma_desc_; + + template + CUTE_HOST_DEVICE friend constexpr void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_smem::value, "Expected smem src for SM90_TMA_STORE"); + //static_assert(is_gmem::value, "Expected gmem dst for SM90_TMA_STORE"); // TMA spoofed src tensor + + void const* const desc_ptr = traits.tma_desc_; + void const* const src_ptr = cute::raw_pointer_cast(src.data()); + auto dst_coord = dst.data().coord_; +#if 0 + auto [c0,c1,c2,c3,c4] = append<5>(dst_coord, 0); + printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", + threadIdx.x, threadIdx.y, threadIdx.z, + blockIdx.x, blockIdx.y, blockIdx.z, + int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), src_ptr); +#endif + return detail::explode_tuple(detail::CallCOPY{}, + make_tuple(desc_ptr, src_ptr), seq<0,1>{}, + dst_coord, tuple_seq{}); + } +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// TMA_REDUCE_ADD ////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +// The executable SM90_TMA_REDUCE_ADD with tma_desc +template +struct Copy_Traits +{ + using ThrID = Layout<_1>; + + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_TMA_REDUCE_ADD arguments + TmaDescriptor tma_desc_; + using AuxParams = AuxParams_; + AuxParams aux_params_; + + // Return TmaDescriptor/TensorMap + CUTE_HOST_DEVICE constexpr + TmaDescriptor const* + get_tma_descriptor() const { + return &tma_desc_; + } + + // Generate the TMA coord tensor + template + CUTE_HOST_DEVICE constexpr + auto + get_tma_tensor(GShape const& g_shape) const { + static_assert(is_congruent::value); + return make_counting_tensor(make_layout(g_shape, aux_params_.g_stride_)); + } + + template + CUTE_HOST_DEVICE constexpr + void + copy_unpack_(void const* const src_ptr, + Coord const& dst_coord, seq) const + { +#if 0 + auto [c0,c1,c2,c3,c4] = append<5>(dst_coord, 0); + printf("THR (%d,%d,%d) BLK (%d,%d,%d) TMACRD (%d,%d,%d,%d,%d) SMEMADDR (%p)\n", + threadIdx.x, threadIdx.y, threadIdx.z, + blockIdx.x, blockIdx.y, blockIdx.z, + int32_t(c0), int32_t(c1), int32_t(c2), int32_t(c3), int32_t(c4), src_ptr); +#endif + + SM90_TMA_REDUCE_ADD::copy(&tma_desc_, + src_ptr, get(dst_coord)...); + } + + // This is the copy_unpack dispatch for this Copy_Traits + // Src needs to be a smem tensor + // Dst needs to be a gmem tensor with 
TmaCoordIterator .data() + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_smem::value, "Expected smem src for SM90_TMA_REDUCE_ADD"); + //static_assert(is_gmem::value, "Expected gmem dst for SM90_TMA_REDUCE_ADD"); // TMA spoofed src tensor + + traits.copy_unpack_(cute::raw_pointer_cast(src.data()), dst.data().coord_, tuple_seq{}); + } +}; + +////////////////////////////////////////////////////////////////////////////// +///////////////////////////// BULK COPY ////////////////////////////////////// +////////////////////////////////////////////////////////////////////////////// + +template +struct Copy_Traits +{ + static_assert(int32_t(NumBitsPerTMA::value / 8) % 16 == 0, + "Bulk Copy requires copy vector size align to 16B."); + + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_BULK_COPY_G2S arguments + // 0: uint64_t* bulk_load_memory_barrier + cute::tuple bulk_load_mbar_; + + // Record the memory barrier for the instruction + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& bulk_mbar) const { + return {&bulk_mbar}; + } + + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_same, cute::tuple>::value, + "Extra arguments not set. Set .with() before use."); + static_assert(is_gmem::value, "Expected gmem src for SM90_BULK_COPY_G2S"); + static_assert(is_smem::value, "Expected smem dst for SM90_BULK_COPY_G2S"); + SM90_BULK_COPY_G2S::copy(raw_pointer_cast(src.data()), get<0>(traits.bulk_load_mbar_), + raw_pointer_cast(dst.data()), int32_t(NumBitsPerTMA::value / 8)); + } +}; + +template +struct Copy_Traits + : Copy_Traits +{ + template + CUTE_HOST_DEVICE + Copy_Traits(Copy_Traits const& traits) {} + + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_gmem::value, "Expected gmem src for SM90_BULK_PREFETCH"); + SM90_BULK_COPY_G2S::PREFETCH::copy(raw_pointer_cast(src.data()), int32_t(NumBitsPerTMA::value / 8)); + } +}; + +template +struct Copy_Traits +{ + static_assert(int32_t(NumBitsPerTMA::value / 8) % 16 == 0, + "Bulk Copy requires copy vector size align to 16B."); + + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout>; + // Reference map from (thr,val) to bit + using RefLayout = SrcLayout; + + template + CUTE_HOST_DEVICE friend constexpr + void + copy_unpack(Copy_Traits const& traits, + Tensor const& src, + Tensor & dst) + { + static_assert(is_smem::value, "Expected smem src for SM90_BULK_COPY_S2G"); + static_assert(is_gmem::value, "Expected gmem dst for SM90_BULK_COPY_S2G"); + SM90_BULK_COPY_S2G::copy(raw_pointer_cast(src.data()), raw_pointer_cast(dst.data()), int32_t(NumBitsPerTMA::value / 8)); + } +}; + +// +// Placeholder for the bulk copy algorithm's default, auto-vectorizing behavior +// + +template +struct Copy_Traits +{ + // Logical thread id to thread idx (one-thread) + using ThrID = Layout<_1>; + // Map from (src-thr,src-val) to bit + using SrcLayout = Layout, Stride<_0,_0>>; + // Map from (dst-thr,dst-val) to bit + using DstLayout = Layout, Stride<_0,_0>>; + // Reference 
map from (thr,val) to bit + using RefLayout = SrcLayout; + + // SM90_UBULK_COPY arguments + // 0: uint64_t* bulk_load_memory_barrier [if this is a BULK_LOAD_G2S] + cute::tuple opargs_; + + // Record the memory barrier for the instruction + CUTE_HOST_DEVICE constexpr + Copy_Traits + with(uint64_t& bulk_mbar) const { + return {&bulk_mbar}; + } +}; + +// +// MAKE_TMA_COPY and related +// + +namespace detail { + +// Custom version of coalesce that greedily combines modes only up to size-256 +// Look at each element and the back of the stack (in order of priority) +// back(NewLayout) get(OldLayout) +// s0:d0 _1:d1 => continue +// _1:d0 s1:d1 => replace_back s1:d1 +// s0:d0 s1:s0*d0 => replace_back s0*s1:d0 if s0*s1 <= 256 +// s0:d0 s1:d1 => append s1:d1 +// +// @pre OldShape and OldStride are flat +template +CUTE_HOST_DEVICE constexpr +auto +coalesce_256_impl(OldShape const& old_shape, OldStride const& old_stride, + NewShape const& new_shape, NewStride const& new_stride) +{ + if constexpr (I == rank_v) { + // Base case, we're done + if constexpr (is_constant<1, NewShape>::value) { + return Layout<_1,_0>{}; + } else { + return Layout{new_shape,new_stride}; + } + } else if constexpr (is_constant<1, decltype(get(old_shape))>::value) { + // shape(layout) == _1, skip it and continue + return coalesce_256_impl(old_shape, old_stride, new_shape, new_stride); + } else if constexpr (is_constant<1, NewShape>::value) { + // Replace our shape-1 with anything (Can only happen on input new_shape/new_stride) + return coalesce_256_impl(old_shape, old_stride, get(old_shape), get(old_stride)); + } else if constexpr (is_constant(old_stride) && + get(old_shape) * back(new_shape) <= Int<256>{})>::value) { + // Merge modes because the shapes and strides match and the merge is 256 or less + return coalesce_256_impl(old_shape, old_stride, + replace_back(new_shape, get(old_shape) * back(new_shape)), + new_stride); + } else { + // Can't replace or merge, so append a new mode + return coalesce_256_impl(old_shape, old_stride, + append(new_shape, get(old_shape)), + append(new_stride, get(old_stride))); + } + + CUTE_GCC_UNREACHABLE; +} + +// Combine all the modes that are possible to combine +// Does not respect the profile of the layout, but does preserve total size +template +CUTE_HOST_DEVICE constexpr +auto +coalesce_256(Layout const& layout) +{ + auto flat_shape = flatten(layout.shape()); + auto flat_stride = flatten(layout.stride()); + return coalesce_256_impl<1>(flat_shape, flat_stride, get<0>(flat_shape), get<0>(flat_stride)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +construct_tma_gbasis(Tensor const& gtensor, // The original GMEM Tensor + Layout const& slayout, // The layout of SMEM + Layout const& cta_v_map) // smem_idx to hier gmode +{ + // + // TMA parameter checking + // + + // CUTE_STATIC_ASSERT_V(product_each(shape(slayout)) == product_each(shape(cta_v_map)), + // "TMA requires CTA_Tile and SLayout top-level shape equivalence."); + CUTE_STATIC_ASSERT_V(size(slayout) == size(cta_v_map), + "TMA requires CTA_Tile and SLayout top-level size equivalence."); + +#if 0 + print("gtensor : "); print(gtensor); print("\n"); + print("slayout : "); print(slayout); print("\n"); + print("cta_v_map : "); print(cta_v_map); print("\n"); +#endif + + // + // TMA slayout manipulation + // + + // Invert the smem to get the largest contiguous vector in the smem layout + // smem idx -> smem coord + auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout)); + + // Compose with the V-Map to convert smem coord (CTA 
val idx) to gmem mode + // smem idx -> gmem mode + auto sidx2gmode_full = coalesce(composition(cta_v_map, inv_smem_layout)); + +#if 0 + print("inv_smem_layout : "); print(inv_smem_layout); print("\n"); + print("sidx2gmode_full : "); print(sidx2gmode_full); print("\n"); +#endif + + // + // TMA gtensor truncation + // + + // Truncate any incompatibilities -- no starting in the middle of gmodes + auto smem_rank = find_if(stride(sidx2gmode_full), [](auto e) { + [[maybe_unused]] auto v = basis_value(e); + return not is_constant<1,decltype(v)>{}; + }); + static_assert(smem_rank > 0, "Could not find a common tile-gmem vectorization. Does the Tile select out major GMEM modes?"); + + // Keep only the static-1 basis modes into gmem + auto sidx2gmode = take<0,smem_rank>(sidx2gmode_full); + +#if 0 + print("smem_rank : "); print(smem_rank); print("\n"); + print("sidx2gmode : "); print(sidx2gmode); print("\n"); +#endif + + // + // TMA gtensor manipulation + // + + // The smem vector is the same units as gtensor, so compose first and then recast + // tma_val_idx:gmem_strides + auto tile_gstride = recast(gtensor.compose(sidx2gmode)).layout(); + // Coalesce modes up to size-256 (the maximum TMA box extent in units of TmaInternalType) + // tma_box_shape:gmem_strides + auto tma_gstride = coalesce_256(tile_gstride); + + // Perform the tiling, recast, and coalesce to the gmem vector again, but with indirections to the gtensor modes + auto gbasis = make_identity_layout(shape(gtensor)); + auto tile_gbasis_tmp = gbasis.compose(sidx2gmode); + + // Instead of the recast (gbasis doesn't have type info), replace the shape with the already-recasted shape + // tma_box_shape:gmem_mode + auto tile_gbasis = make_layout(shape(tile_gstride), stride(tile_gbasis_tmp)); + + // "Coalesce" the tile basis into a compatible shape with the tma_gstride + auto tma_gbasis_tile = tile_gbasis.compose(make_layout(wrap(shape(tma_gstride)))); + + // Recast the original tensor for shape/stride inspections + Tensor gtensor_T = recast(gtensor); + + // Find missing bases that don't appear in tile_gbasis + auto tile_gbasis_remaining_stride = filter_tuple(flatten(shape (gtensor_T)), flatten(stride(gtensor_T)), + flatten(stride(gbasis)), + [&](auto s, auto d, auto e) + { + if constexpr (is_constant<1, decltype(s)>::value || is_constant<0, decltype(d)>::value) { + return cute::tuple<>{}; // If size-1 or stride-0, then don't append + } else { + using E = decltype(e); + auto has_e = any_of(flatten(stride(tma_gbasis_tile)), [] (auto tb) { return tb == E{}; }); + if constexpr (decltype(has_e)::value) { + return cute::tuple<>{}; // If d was found, then don't append + } else { + return cute::tuple(e); // Else, this is missing so append + } + } + }); + + // Append the remaining basis modes that contribute to the TMA with size-1 + auto tile_gbasis_remaining_shape = repeat(Int<1>{}); + auto tma_gbasis_full = make_layout(tuple_cat(wrap( shape(tma_gbasis_tile)), wrap(tile_gbasis_remaining_shape )), + tuple_cat(wrap(stride(tma_gbasis_tile)), wrap(tile_gbasis_remaining_stride))); + + // Group the trailing modes to make this max rank-5 -- TMA rank limitation + // tma_box_shape:gmem_mode + auto tma_gbasis = group(tma_gbasis_full); + +#if 0 + print("tile_gstride : "); print(tile_gstride); print("\n"); + print("tma_gstride : "); print(tma_gstride); print("\n"); + print("gbasis : "); print(gbasis); print("\n"); + print("tile_gbasis : "); print(tma_gbasis_tile); print("\n"); + print("tma_gbasis : "); print(tma_gbasis); print("\n"); +#endif + + return tma_gbasis; +} 
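+
+// Worked illustration of coalesce_256 (added commentary; results are inferred from
+// the merge rules documented above rather than verified output):
+//   (_128,_2):(_1,_128) -> _256:_1              contiguous and 128*2 <= 256, so the
+//                                               two modes merge into one TMA box mode
+//   (_256,_2):(_1,_256) -> (_256,_2):(_1,_256)  merging would exceed the 256-element
+//                                               box-extent limit, so the modes stay split
+//   (_128,_4):(_1,_256) -> (_128,_4):(_1,_256)  stride 256 != 128*1 (padded gmem), so
+//                                               the modes cannot be merged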
+ +template +CUTE_HOST_DEVICE constexpr +void +fill_tma_gmem_shape_stride(Tensor const& gtensor, // Gmem Shapes and Strides, in units of TmaInternalType + TmaGmemBasisStride const& tma_gbasis_stride, // Map Tma mode idx -> Gmem mode(s) + cute::array & gmem_prob_shape, // Tma Shapes, uint32_t or uin64_t + cute::array & gmem_prob_stride) // Tma Strides +{ + static_assert(is_tuple::value); + static_assert(is_same::value || is_same::value); + + using TmaInternalType = typename GEngine::value_type; + constexpr int tma_rank = decltype(rank(tma_gbasis_stride))::value; + static_assert(TmaRank >= tma_rank); + + auto gmem_shape = shape(gtensor); + auto gmem_stride = stride(gtensor); + // Use the indirections in tma_gbasis_stride into gtensor to construct the tma gmem shapes/strides + for_each(make_seq{}, [&](auto i) { + constexpr int tma_i_rank = decltype(rank(tma_gbasis_stride))::value; + if constexpr (tma_i_rank == 1) { + // Trivial contribution of this gmem mode to this tma mode + auto ej = unwrap(get(tma_gbasis_stride)); + gmem_prob_shape[i] = basis_get(ej, gmem_shape); + gmem_prob_stride[i] = basis_get(ej, gmem_stride); + } else { + // Apply a recurrence to each gmem mode that contributes to this tma mode + for_each(get(tma_gbasis_stride), [&](auto ej) { + // Problem shape + uint64_t shape_j = basis_get(ej, gmem_shape); + // Problem stride (in bytes) + uint64_t stride_j = basis_get(ej, gmem_stride); + uint64_t old_stride = gmem_prob_stride[i]; + gmem_prob_stride[i] = gcd(gmem_prob_stride[i], stride_j); + + if (gmem_prob_stride[i] != 0) { + // Recurrence: g_shape = (s_i - 1) * (d_i / gcd_j d_j) + 1 + gmem_prob_shape[i] = (gmem_prob_shape[i]-1) * (old_stride / gmem_prob_stride[i]) + + (shape_j-1) * (stride_j / gmem_prob_stride[i]) + + 1; + } else { + gmem_prob_shape[i] = shape_j; + } + }); + } + }); +} + +// Overload for an existing Copy_Traits +template +CUTE_HOST_DEVICE constexpr +void +fill_tma_gmem_shape_stride(Copy_Traits const& tma_traits, + Tensor const& gtensor, // Gmem Shapes and Strides, value_type = TmaInternalType + cute::array & gmem_prob_shape, // Tma Shapes, uint32_t or uin64_t + cute::array & gmem_prob_stride) // Tma Strides +{ + return fill_tma_gmem_shape_stride(gtensor, stride(typename Aux::TmaGmemBasis{}), + gmem_prob_shape, gmem_prob_stride); +} + +// Use a sidx2gmode to read through the GMEM tensor +// and construct a TMA Descriptor for the resulting instruction +// At the same time, construct the Tma Tensor's Stride to generate +// the TMA coordinates that the instruction consumes. 
+// +template +CUTE_HOST_RTC +auto +make_tma_copy_desc(Tensor const& gtensor, // The original GMEM Tensor + Layout const& tma_gbasis, // TMA mode -> GMEM mode mapping + Swizzle const& swizzle, // Swizzle fn on smem_idx + uint32_t num_multicast) // The number of CTAs in multicasting +{ + // + // TMA desc creation + // + + constexpr int tma_dim = decltype(rank(tma_gbasis))::value; + + // + // TMA gmem desc info + // + + // Recast the original tensor for shape/stride inspections + Tensor gtensor_T = recast(gtensor); + + void* gmem_address = (void*) raw_pointer_cast(gtensor_T.data()); + auto gmem_layout = gtensor_T.layout(); + + cute::array gmem_prob_shape = {1,1,1,1,1}; + cute::array gmem_prob_stride = {0,0,0,0,0}; + + fill_tma_gmem_shape_stride(gtensor_T, stride(tma_gbasis), gmem_prob_shape, gmem_prob_stride); + + assert((reinterpret_cast(gmem_address) & 0b1111) == 0); // Address must be 16B-aligned + + assert(gmem_prob_shape[0] >= (uint64_t(1))); // Size must be min 1 + assert(gmem_prob_shape[0] <= (uint64_t(1) << 32)); // Size must be max 2^32 + assert(gmem_prob_shape[1] >= (uint64_t(1))); // Size must be min 1 + assert(gmem_prob_shape[1] <= (uint64_t(1) << 32)); // Size must be max 2^32 + assert(gmem_prob_shape[2] >= (uint64_t(1))); // Size must be min 1 + assert(gmem_prob_shape[2] <= (uint64_t(1) << 32)); // Size must be max 2^32 + assert(gmem_prob_shape[3] >= (uint64_t(1))); // Size must be min 1 + assert(gmem_prob_shape[3] <= (uint64_t(1) << 32)); // Size must be max 2^32 + assert(gmem_prob_shape[4] >= (uint64_t(1))); // Size must be min 1 + assert(gmem_prob_shape[4] <= (uint64_t(1) << 32)); // Size must be max 2^32 + + // TMA descriptor does not store the zeroth stride and assumes it is 1 (TmaInternalType element). + assert(gmem_prob_stride[0] == 1 && "Majorness of smem doesn't match majorness of gmem"); + + // convert strides to byte strides + for(uint64_t& stride : gmem_prob_stride) { + stride = (stride * sizeof_bits_v) / 8; + } + + // Assert the byte strides. 
Tma Descriptor uses byte strides + assert((gmem_prob_stride[1]) < (uint64_t(1) << 40)); // Stride must be max 2^40 + assert((gmem_prob_stride[1] & 0b1111) == 0); // Stride must be multiple of 16B (128b) + assert((gmem_prob_stride[2]) < (uint64_t(1) << 40)); // Stride must be max 2^40 + assert((gmem_prob_stride[2] & 0b1111) == 0); // Stride must be multiple of 16B (128b) + assert((gmem_prob_stride[3]) < (uint64_t(1) << 40)); // Stride must be max 2^40 + assert((gmem_prob_stride[3] & 0b1111) == 0); // Stride must be multiple of 16B (128b) + assert((gmem_prob_stride[4]) < (uint64_t(1) << 40)); // Stride must be max 2^40 + assert((gmem_prob_stride[4] & 0b1111) == 0); // Stride must be multiple of 16B (128b) + + // + // TMA smem desc info + // + + cute::array smem_box_shape = {1,1,1,1,1}; + cute::array smem_box_stride = {1,1,1,1,1}; + // The smem box is simply given by the sizes of the modes in tma_gbasis + for_each(make_seq{}, [&](auto i) { + smem_box_shape[i] *= size(tma_gbasis); + }); + // Finally, truncate the tma box by the num_multicast + for (uint32_t i = tma_dim-1, multicast = num_multicast; multicast > 1; --i) { + assert(smem_box_shape[i] % multicast == 0 || multicast % smem_box_shape[i] == 0); + uint32_t new_mult = ceil_div(multicast, smem_box_shape[i]); + smem_box_shape[i] = ceil_div(smem_box_shape[i], multicast); + multicast = new_mult; + } + + assert(smem_box_shape[0] >= (uint32_t(1))); // Size must be min 1 + assert(smem_box_shape[0] <= (uint32_t(1) << 8)); // Size must be max 2^8 = 256 + assert(smem_box_shape[1] >= (uint32_t(1))); // Size must be min 1 + assert(smem_box_shape[1] <= (uint32_t(1) << 8)); // Size must be max 2^8 = 256 + assert(smem_box_shape[2] >= (uint32_t(1))); // Size must be min 1 + assert(smem_box_shape[2] <= (uint32_t(1) << 8)); // Size must be max 2^8 = 256 + assert(smem_box_shape[3] >= (uint32_t(1))); // Size must be min 1 + assert(smem_box_shape[3] <= (uint32_t(1) << 8)); // Size must be max 2^8 = 256 + assert(smem_box_shape[4] >= (uint32_t(1))); // Size must be min 1 + assert(smem_box_shape[4] <= (uint32_t(1) << 8)); // Size must be max 2^8 = 256 + + assert(smem_box_stride[0] >= (uint32_t(1))); // Stride must be min 1 + assert(smem_box_stride[0] <= (uint32_t(8))); // Stride must be max 2^3 = 8 + assert(smem_box_stride[1] >= (uint32_t(1))); // Stride must be min 1 + assert(smem_box_stride[1] <= (uint32_t(8))); // Stride must be max 2^3 = 8 + assert(smem_box_stride[2] >= (uint32_t(1))); // Stride must be min 1 + assert(smem_box_stride[2] <= (uint32_t(8))); // Stride must be max 2^3 = 8 + assert(smem_box_stride[3] >= (uint32_t(1))); // Stride must be min 1 + assert(smem_box_stride[3] <= (uint32_t(8))); // Stride must be max 2^3 = 8 + assert(smem_box_stride[4] >= (uint32_t(1))); // Stride must be min 1 + assert(smem_box_stride[4] <= (uint32_t(8))); // Stride must be max 2^3 = 8 + + // + // Construct the descriptor + // + + TmaDescriptor tma_desc{}; + + // + // TMA general info + // + + #if (__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__) + + CUtensorMapDataType tma_format = TMA::to_CUtensorMapDataType(); + CUtensorMapInterleave tma_interleave = CU_TENSOR_MAP_INTERLEAVE_NONE; + CUtensorMapL2promotion tma_l2Promotion = CU_TENSOR_MAP_L2_PROMOTION_L2_128B; + CUtensorMapFloatOOBfill tma_oobFill = CU_TENSOR_MAP_FLOAT_OOB_FILL_NONE; + + // TMA smem swizzle type + TMA::SmemSwizzleBits swizzle_bits = get_tma_swizzle_bits(swizzle); + TMA::SmemSwizzleBase swizzle_base = get_tma_swizzle_base(swizzle); + CUtensorMapSwizzle smem_swizzle = 
TMA::to_CUtensorMapSwizzle(swizzle_bits, swizzle_base); + CUresult result = CUTLASS_CUDA_DRIVER_WRAPPER_CALL(cuTensorMapEncodeTiled)( + &tma_desc, + tma_format, + tma_dim, + gmem_address, + gmem_prob_shape.data(), + gmem_prob_stride.data() + 1, // gmem_prob_stride[0] implicitly 1 + smem_box_shape.data(), + smem_box_stride.data(), + tma_interleave, + smem_swizzle, + tma_l2Promotion, + tma_oobFill); + + if (result != CUDA_SUCCESS) { + std::cerr << "TMA Desc Addr: " << &tma_desc + << "\nformat " << tma_format + << "\ndim " << tma_dim + << "\ngmem_address " << gmem_address + << "\nglobalDim " << gmem_prob_shape + << "\nglobalStrides " << gmem_prob_stride + << "\nboxDim " << smem_box_shape + << "\nelementStrides " << smem_box_stride + << "\ninterleave " << tma_interleave + << "\nswizzle " << smem_swizzle + << "\nl2Promotion " << tma_l2Promotion + << "\noobFill " << tma_oobFill << std::endl; + std::cerr << "Error: Failed to initialize the TMA descriptor " << result << std::endl; + assert(false); + } + + #endif // (__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__) + auto recast_ratio = cute::trait_ratio(sizeof_bits{}, + sizeof_bits< TmaInternalType>{}); + + auto gbasis = make_basis_like(shape(gtensor)); + + // Finally, get the inverse permutation of the E bases for the mocked gmem stride + auto gmem_tma_basis_stride = transform_leaf(gbasis, [&](auto ei) { + auto si = basis_get(ei, shape(gmem_layout)); + auto di = basis_get(ei, stride(gmem_layout)); + if constexpr (is_constant<1, decltype(si)>::value || is_constant<0, decltype(di)>::value) { + return Int<0>{}; // If size-1 or stride-0, return arithmetic identity -- no contribution to the TMA + } else { + auto tma_gmem_basis_stride = stride(tma_gbasis); + // Find j such that E is in stride(tma_gbasis) + using EI = decltype(ei); + [[maybe_unused]] auto j = find_if(tma_gmem_basis_stride, [&](auto tma_stride_j) { return any_of(tma_stride_j, [&](auto dj) { return dj == EI{}; }); }); + if constexpr (decltype(j == rank(tma_gmem_basis_stride))::value) { + return Int<0>{}; // If not-found, return arithmetic identity -- no contribution to the TMA + } else + if constexpr (decltype(j == Int<0>{})::value) { + auto scale = recast_ratio * basis_get(ei, stride(gtensor)); + return E{} * scale; // Return TMA Coord basis -- with a recast scale factor + } else + if constexpr (decltype(rank(tma_gmem_basis_stride) == Int<1>{})::value) { + return E{}; // Return TMA Coord basis -- known scale of Int<1>{} + } else { + int32_t scale = ceil_div(int32_t(di * sizeof_bits_v / cute::max(gmem_prob_stride[j], uint64_t{16})), 8); + return E{} * scale; // Return TMA Coord basis -- with a dynamic scale factor + } + } + }); + +#if 0 + print("gmem_tma_basis_stride : "); print(gmem_tma_basis_stride); print("\n"); +#endif + + using AuxParams = AuxTmaParams; + return cute::make_tuple(tma_desc, AuxParams{gmem_tma_basis_stride}); +} + +template +CUTE_HOST_RTC +auto +make_tma_copy_atom(CopyOp, + Tensor const& gtensor, // Full GMEM Tensor + SLayout const& slayout, // CTA Tile of SMEM, potentially swizzled + uint32_t const& num_multicast, // The number of CTAs involved in multicasting + Layout const& cta_v_map) // V: CTA val idx -> gmem mode +{ + // + // TMA truncated layout + // + + auto smem_swizzle = get_swizzle_portion(slayout); + auto smem_layout = get_nonswizzle_portion(slayout); + + auto tma_gbasis = detail::construct_tma_gbasis(gtensor, smem_layout, cta_v_map); + + // + // Construct the TMA Desc and the strides of the TMA Tensor + // + + auto [tma_desc, aux_params] = 
detail::make_tma_copy_desc(gtensor, + tma_gbasis, + smem_swizzle, + num_multicast); + + // + // Construct the Copy_Traits + // + + constexpr int num_bits_per_tma = size(tma_gbasis) * sizeof_bits_v; + using Traits = Copy_Traits, decltype(aux_params)>; + using Atom = Copy_Atom; + + Traits tma_traits{tma_desc, aux_params}; + +#if 0 + print("num_bits_per_tma : "); print(num_bits_per_tma); print("\n"); + print("g_stride_bases : "); print(tma_traits.aux_params_.g_stride_); print("\n"); +#endif + + // Return the Copy_Atom + return Atom{tma_traits}; +} + +// The "logical TMA tid" is a map from the CTA rank to its logical id +// within the instruction. It works like a mask or ordering on the +// CTAs. For non-multicast TMA, all CTAs should map to 0. For +// multicast TMA of size 4, CTAs will be mapped to {0,1,2,3}. +template +CUTE_HOST_RTC +auto +make_tma_copy_tiled(CopyOp const& copy_op, + Tensor const& gtensor, // Full GMEM Tensor + SLayout const& slayout, // CTA Tile of SMEM + Layout const& cta_t_map, // T: CTA thr idx -> logical TMA tid + Layout const& cta_v_map) // V: CTA val idx -> gmem mode +{ + Copy_Atom atom = make_tma_copy_atom(copy_op, gtensor, slayout, + cosize(cta_t_map), cta_v_map); + + // + // Construct the TiledCopy + // + + [[maybe_unused]] auto cta_tiler = product_each(shape(cta_v_map)); + + auto num_elems_per_tma = size<1>(typename decltype(atom)::RefLayout{}) / static_value>(); + + // smem idx -> smem coord + auto inv_smem_layout = right_inverse(get_nonswizzle_portion(slayout)); + // CTA V -> smem_coord + auto layout_v = composition(inv_smem_layout, num_elems_per_tma); + // Scale that up to cover all of the smem_coords + auto layout_V = tile_to_shape(make_layout(layout_v), size(cta_v_map)); + // CTA T -> smem idx + auto layout_t = make_layout(cosize(cta_t_map), shape_div(num_elems_per_tma, cosize(cta_t_map))); + // CTA TID -> smem coord + auto layout_T = composition(inv_smem_layout, composition(layout_t, cta_t_map)); + // Combine with the T mapping + [[maybe_unused]] auto layout_TV = make_layout(layout_T, layout_V); + +#if 0 + print("cta_tiler : "); print(cta_tiler); print("\n"); + print("layout_v : "); print(layout_v); print("\n"); + print("layout_V : "); print(layout_V); print("\n"); + print("layout_t : "); print(layout_t); print("\n"); + print("layout_T : "); print(layout_T); print("\n"); + print("layout_TV : "); print(layout_TV); print("\n"); +#endif + + return TiledCopy{atom}; +} + +} // end namespace detail + +/** Make a CuTe CTA-collective TiledCopy for a TMA operation. + * + * @param CopyOp The target copy operation: SM90_TMA_LOAD, SM90_TMA_LOAD_MULTICAST, SM90_TMA_STORE + * @param gtensor The GMEM Tensor to be involved in the TMA. + * @param slayout The SMEM Layout to be involved in the TMA. + * @param cta_tile The CTA-local tile that each CTA will be tiling GMEM with. + * This is often the blk_shape that is used to tile the GMEM for CTAs: + * local_tile(gtensor, blk_shape, blk_coord) -> CTA-local tile of gtensor + * @param cluster_size When using SM90_TMA_LOAD_MULTICAST, this can be a (static) power-of-2 <= 16 + * defining the multicast size (used to further partition the SMEM) + * Else, static-1 + * + * This code attempts to maximize the TMA box size. It does this by tracing + * the SMEM "vector" -- the inverse of the smem layout -- to find the largest + * contiguous array of smem that can be written to/from global memory given + * the constraints that the TMA instruction imposes. 
+ * + * This is accomplished by assigning "basis" strides to the GMEM to track which + * modes of SMEM map to which modes of GMEM, then reorder the modes of GMEM according + * to the SMEM vector, and then using those GMEM/SMEM modes to fill in the desc. + * + * Examples: + using T = float; + T* gptr = nullptr; + + { + // Simple 2D + Tensor gtensor = make_tensor(gptr, make_shape(1024, 256), GenRowMajor{}); // K-Major GMEM + auto slayout = make_layout(make_shape(_64{}, _32{}), GenRowMajor{}); // K-Major SMEM + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout); + } + + { + // GMMA 2D + Tensor gtensor = make_tensor(gptr, make_shape(1024, 256)); // MN-Major GMEM + auto slayout = tile_to_shape(GMMA::Layout_MN_SW128_Atom{}, make_shape(_128{},_64{})); // MN-Major Swizzled+Tiled 128x64 SMEM + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout); + } + + { + // 3D + Tensor gtensor = make_tensor(gptr, make_shape(1024, 32, 512), make_stride(64, Int<1>{}, 65536)); // GMEM + auto slayout = make_layout(make_shape(_16{}, _8{}, _2{}), make_stride(_16{}, _1{}, _8{})); // SMEM w/ same major-mode + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout); + } + + { + // cuTENSOR 4D + auto layout = make_shape(make_shape(32,40),make_shape(make_shape(8,8),656)); // GMEM + auto cta_tile = make_shape(_128{},make_shape(_32{},_2{})); // GMEM Tiling: + // Take 128-elem from m: m0 must divide 128, + // m-last may be predicated + // Take 32-elem from k0, 2-elem from k1 + auto slayout = make_layout(cta_tile); // Col-Major SMEM + auto tma = make_tma_copy(SM90_TMA_LOAD{}, gtensor, slayout, cta_tile, Int<1>{}); + } + * + * Check the TMA box size and desc: + print("TMA Box size: "); print(typename decltype(tma)::Tiler_MN{}); print("\n"); + print("TMA desc : "); print(tma.tma_desc_); print("\n"); + * + * Usage: + Tensor mA = tma_a.get_tma_tensor(make_shape(M,N)); // (M,N) TMA coord tensor + Tensor gA = local_tile(mA, cta_tile, cta_coord); // (BLK_M,BLK_N) TMA coord tensor for this CTA + Tensor sA = make_tensor(make_smem_ptr(sptr), slayout); // (BLK_M,BLK_N) SMEM tensor + + auto cta_tma = tma.get_slice(cta_idx_in_cluster); // Slice for multicast partitioning + Tensor tAgA = cta_tma.partition_S(gA); // Partition for src + Tensor tAsA = cta_tma.partition_D(sA); // Partition for dst + + copy(tma.with(barrier, mcast_mask), tAgA, tAsA); // copy with supporting TMA params + */ +template +CUTE_HOST_RTC +auto +make_tma_copy(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + CTA_Tiler const& cta_tiler, + Cluster_Size const& cluster_size) +{ + if constexpr (cute::is_same_v || + cute::is_same_v) { + return make_im2col_tma_copy(copy_op, + gtensor, + slayout, + cta_tiler, + cluster_size); + } else { + auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler); + auto cta_t_tile = make_layout(cluster_size); + // Prefer TmaInternalType if specified. 
Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + return detail::make_tma_copy_tiled(copy_op, + gtensor, slayout, + cta_t_tile, cta_v_tile); + } +} + +// Explicit defaulting +template +CUTE_HOST_RTC +auto +make_tma_copy(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout) +{ + return make_tma_copy(copy_op, gtensor, slayout, product_each(shape(slayout)), Int<1>{}); +} + +// Explicit defaulting +template +CUTE_HOST_RTC +auto +make_tma_copy(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + Cluster_Size const& cluster_size) +{ + return make_tma_copy(copy_op, gtensor, slayout, product_each(shape(slayout)), cluster_size); +} + +//////////////////////////////////// +// Experimental Make TMA Atom and Partitioner +/////////////////////////////////// + +template > +CUTE_HOST_RTC +auto +make_tma_atom(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + CTA_Tiler const& cta_tiler, + Cluster_Size const& cluster_size = {}) +{ + auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler); + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + return detail::make_tma_copy_atom(copy_op, + gtensor, slayout, + size(cluster_size), cta_v_tile); +} + +// The "VectorCopy Partitioner" for TMA +template +CUTE_DEVICE +auto +tma_partition(Copy_Atom const& copy_atom, + CtaCoord const& cta_coord, + Layout const& cta_layout, // T: CTA coord -> logical multicast id + Tensor const& stensor, // SMEM Tensor (TMATile, Rest...) + Tensor const& gtensor) // GMEM Tensor (TMATile, Rest...) +{ + CUTE_STATIC_ASSERT_V(size<0>(stensor) == size<0>(gtensor)); + + // Invert the smem to get the largest contiguous vector in the smem layout + Layout inv_smem_layout = right_inverse(get_nonswizzle_portion(layout<0>(stensor))); + // Scale that up to cover all of the smem_coords + Layout layout_v = tile_to_shape(make_layout(inv_smem_layout), size<0>(stensor)); + + // Factor out the single-instrucion portion + Layout tma_layout_v = make_layout(Int::NumValSrc>{}); + auto layout_V = make_tile(logical_divide(layout_v, tma_layout_v)); + + // Append with _ until we cover all Rest... modes + auto glayout_V = append(layout_V, _); + auto slayout_V = append(layout_V, _); + // Transform tile mode and coalesce + Tensor gtensor_v = coalesce(gtensor.compose(glayout_V), Shape>{}); // ((TMA,TMA_Iter), Rest...) + Tensor stensor_v = coalesce(stensor.compose(slayout_V), Shape>{}); // ((TMA,TMA_Iter), Rest...) 
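+
+  // Note (added commentary): when multicasting, each participating CTA covers an
+  // equal 1/cosize(cta_layout) slice of the single-instruction TMA vector. The
+  // slice is selected below by offsetting the TMA mode by this CTA's logical
+  // multicast id, cta_layout(cta_coord).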
+ +#if 0 + if (thread0()) { + print("cta_coord : "); print(cta_coord); print("\n"); + print("cta_layout : "); print(cta_layout); print("\n"); + print("gtensor : "); print(gtensor); print("\n"); + print("stensor : "); print(stensor); print("\n"); + print("layout_V : "); print(layout_V); print("\n"); + print("gtensor_v : "); print(gtensor_v); print("\n"); + print("stensor_v : "); print(stensor_v); print("\n"); + } +#endif + + // Offset inside the TMA-mode for the multicast + auto multicast_offset = cta_layout(cta_coord) * (size(tma_layout_v) / cosize(cta_layout)); + auto multicast_coord = make_coord(make_coord(multicast_offset, Int<0>{})); + auto gcoord = append(multicast_coord, Int<0>{}); + auto scoord = append(multicast_coord, Int<0>{}); + + Tensor gresult = domain_offset(gcoord, gtensor_v); + Tensor sresult = domain_offset(scoord, stensor_v); + + return cute::make_tuple(gresult, sresult); +} + +// Explicit defaults for cta_coord and cta_layout +template +CUTE_DEVICE +auto +tma_partition(Copy_Atom const& copy_atom, + Tensor const& stensor, // SMEM Tensor (TMATile, Rest...) + Tensor const& gtensor) // GMEM Tensor (TMATile, Rest...) +{ + return tma_partition(copy_atom, Int<0>{}, Layout<_1,_0>{}, stensor, gtensor); +} + +// TMA Multicast Masks Calculation +template +CUTE_HOST_DEVICE constexpr +uint16_t +create_tma_multicast_mask(CtaLayout const& cta_layout_vmnk, + CtaCoord const& cta_coord_vmnk) +{ + auto cta_coord_slicer = replace(cta_coord_vmnk, _); + auto [cta_layout, elected_cta] = slice_and_offset(cta_coord_slicer, cta_layout_vmnk); + + uint16_t mcast_mask = 0; + if constexpr (rank_v == 1 and depth_v <= 1 and + not is_static::value) { + // Get the instruction code -- optimized for dynamic flat-rank-1 cta_layout + mcast_mask = uint16_t(1); + // Smear by stride<0> (may want to predicate on stride<0> mag?) + mcast_mask |= mcast_mask << (1*stride<0>(cta_layout)); + mcast_mask |= mcast_mask << (2*stride<0>(cta_layout)); + mcast_mask |= mcast_mask << (4*stride<0>(cta_layout)); + mcast_mask |= mcast_mask << (8*stride<0>(cta_layout)); + // Select shape<0> + mcast_mask &= (uint16_t(-1) >> (16 - shape<0>(cta_layout) * stride<0>(cta_layout))); + } else { + // Get the instruction code -- generic path + for (int i = 0; i < size(cta_layout); ++i) { + mcast_mask |= uint16_t(1) << cta_layout(i); + } + } + // Shift by the instruction's elected block rank (dynamic) + mcast_mask <<= elected_cta; + return mcast_mask; +} + +//////////////////////////////////// +// Make TMA copy A/B/C +/////////////////////////////////// + +template +CUTE_HOST_RTC +auto +make_tma_copy_A_sm90(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + CTA_Tiler const& cta_tiler, + Cluster_Size const& cluster_size) +{ + // Keep only MK modes from MNK + auto cta_tiler_mk = remove<1>(cta_tiler); + + // mcast along N mode for this M load, if any + auto cluster_size_n = size<1>(cluster_size); + + if constexpr (cute::is_same_v) { + return make_im2col_tma_copy(copy_op, + gtensor, + slayout, + cta_tiler_mk, + cluster_size_n); + } else { + auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler_mk); + auto cta_t_tile = make_layout(cluster_size_n); + + // Prefer TmaInternalType if specified. 
Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + auto tma_copy = detail::make_tma_copy_tiled(copy_op, gtensor, slayout, cta_t_tile, cta_v_tile); + return tma_copy; + } +} + +template +CUTE_HOST_RTC +auto +make_tma_copy_B_sm90(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + CTA_Tiler const& cta_tiler, + Cluster_Size const& cluster_size) +{ + // Keep only NK modes from MNK + auto cta_tiler_nk = remove<0>(cta_tiler); + + // mcast along M mode for this N load, if any + auto cluster_size_m = size<0>(cluster_size); + + if constexpr (cute::is_same_v) { + return make_im2col_tma_copy(copy_op, + gtensor, + slayout, + cta_tiler_nk, + cluster_size_m); + } else { + auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler_nk); + auto cta_t_tile = make_layout(cluster_size_m); + + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + auto tma_copy = detail::make_tma_copy_tiled(copy_op, gtensor, slayout, cta_t_tile, cta_v_tile); + return tma_copy; + } +} + +template +CUTE_HOST_RTC +auto +make_tma_copy_C_sm90(CopyOp const& copy_op, + Tensor const& gtensor, + SLayout const& slayout, + CTA_Tiler const& cta_tiler) +{ + // Keep only MN modes from MNK + auto cta_tiler_mn = remove<2>(cta_tiler); + + if constexpr (cute::is_same_v || + cute::is_same_v) { + return make_im2col_tma_copy(copy_op, + gtensor, + slayout, + cta_tiler_mn, + _1{}); + } else { + auto cta_v_tile = make_identity_layout(shape(gtensor)).compose(cta_tiler_mn); + + // No multicast, so only 1 CTA involved + auto cta_t_map = Layout<_1,_0>{}; + + // Prefer TmaInternalType if specified. Fallback to GEngine::value_type + using TmaType = conditional_t::value, typename GEngine::value_type, TmaInternalType>; + auto tma_copy = detail::make_tma_copy_tiled(copy_op, gtensor, slayout, cta_t_map, cta_v_tile); + return tma_copy; + } +} +} // end namespace cute diff --git a/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp b/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp new file mode 100644 index 0000000000..3286e72b36 --- /dev/null +++ b/include/cute/atom/copy_traits_sm90_tma_swizzle.hpp @@ -0,0 +1,93 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +/// @file copy_traits_sm90_tma_swizzle.hpp +/// @brief Functions for converting swizzle layout to TMA descriptor + +#if !defined(__CUDACC_RTC__) +#include +#endif + +#include +#include + +namespace cute::detail { + +template +CUTE_HOST_DEVICE constexpr +TMA::SmemSwizzleBits +get_tma_swizzle_bits(Swizzle) +{ + if constexpr (M == 4) { + switch (B) { + default: static_assert(0 <= B && B <= 3, "Expected B = 0,1,2, or 3 when M == 4. Unsupported layout swizzle."); + case 3: return TMA::SmemSwizzleBits::B128; + case 2: return TMA::SmemSwizzleBits::B64; + case 1: return TMA::SmemSwizzleBits::B32; + case 0: return TMA::SmemSwizzleBits::DISABLE; + } + } else + { + static_assert(M < 0, "Unsupported layout swizzle."); + } +} + +template +TMA::SmemSwizzleBits +get_tma_swizzle_bits(Layout const& layout) +{ + return get_tma_swizzle_bits(get_swizzle_portion(layout)); +} + +template +CUTE_HOST_DEVICE constexpr +TMA::SmemSwizzleBase +get_tma_swizzle_base(Swizzle) +{ + if constexpr (M == 4) { + static_assert(0 <= B && B <= 3, "Expected B = 0,1,2, or 3 when M == 4. Unsupported layout swizzle."); + static_assert(S == 3, "Expected S = 3 when M == 4. Unsupported layout swizzle."); + return TMA::SmemSwizzleBase::SWIZZLE_BASE_16B; + } + else { + static_assert(M == 4, "Expected 128b=16B=(2^4)B base swizzle."); + } +} + +template +TMA::SmemSwizzleBase +get_tma_swizzle_base(Layout const& layout) +{ + return get_tma_swizzle_base(get_swizzle_portion(layout)); +} + +} // namespace cute::detail diff --git a/include/cute/atom/mma_atom.hpp b/include/cute/atom/mma_atom.hpp new file mode 100644 index 0000000000..7cb4fe3df2 --- /dev/null +++ b/include/cute/atom/mma_atom.hpp @@ -0,0 +1,1112 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include +#include +#include + +namespace cute { + +template +struct MMA_Atom; + +template +struct MMA_Atom : MMA_Atom> +{}; + +template +struct MMA_Atom> + : MMA_Traits +{ + using MMA_Op = MMAOperation; + using Traits = MMA_Traits; + + // Element value types from the MMA_Traits + using ValTypeD = typename Traits::ValTypeD; + using ValTypeA = typename Traits::ValTypeA; + using ValTypeB = typename Traits::ValTypeB; + using ValTypeC = typename Traits::ValTypeC; + + // Thr-Val layouts from the MMA_Traits + using Shape_MNK = typename Traits::Shape_MNK; + using ThrID = typename Traits::ThrID; + using LayoutC_TV = typename Traits::CLayout; + using LayoutA_TV = typename Traits::ALayout; + using LayoutB_TV = typename Traits::BLayout; + + // Fragment value types from the MMA_Traits (optional, defaults to Val type) + using FrgTypeD = typename detail::FrgTypeC_or_Default::type; + using FrgTypeA = typename detail::FrgTypeA_or_Default::type; + using FrgTypeB = typename detail::FrgTypeB_or_Default::type; + using FrgTypeC = typename detail::FrgTypeC_or_Default::type; + + // Additional Trait parameters/transformations + template + CUTE_HOST_DEVICE + auto + with(TraitsArgs&&... args) const { + auto traits = Traits::with(static_cast(args)...); + return MMA_Atom{traits}; + } + + // + // Tensor call interfaces + // + + // Cast, check, and call fma + template + CUTE_HOST_DEVICE constexpr + void + call(Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) const + { + static_assert(DLayout::rank == 1, "Expected rank-1 D tensor"); + static_assert(ALayout::rank == 1, "Expected rank-1 A tensor"); + static_assert(BLayout::rank == 1, "Expected rank-1 B tensor"); + static_assert(CLayout::rank == 1, "Expected rank-1 C tensor"); + + return mma_unpack(static_cast(*this), D, A, B, C); + } + + // Three arguments reproduces C + template + CUTE_HOST_DEVICE constexpr + void + call(Tensor const& A, + Tensor const& B, + Tensor & C) const + { + return call(C, A, B, C); + } + + // + // make_fragment_A|B|C + // These functions are awkward as they expect already-partitioned tensors + // resulting from a previous call to partition_A|B|C + // The reasoning is that we can inspect the layout of the partitioned data + // and attempt to match it in generated fragment to promote vectorization + // when copying from partition to fragment. 
+ // + + template + CUTE_HOST_DEVICE static constexpr + auto + make_fragment_C(CTensor&& ctensor) + { + // Check that this tensor is likely already partitioned + CUTE_STATIC_ASSERT_V(rank(ctensor) >= Int<3>{}); // VMN + CUTE_STATIC_ASSERT_V(size<0>(ctensor) == size<1>(LayoutC_TV{})); + // C is a bit special because we are after accumulators here + // The input/output type doesn't have to match the accumulator type + //static_assert(std::is_same::value_type>::value, "Expecting ValTypeC type"); + + // We'll never base the accumulator layout on the input tensor layout, so just return a FrgTypeC tensor + return make_tensor(shape(ctensor)); + } + + template + CUTE_HOST_DEVICE static constexpr + auto + make_fragment_A(ATensor&& atensor) + { + // Check that this tensor is likely already partitioned + CUTE_STATIC_ASSERT_V(rank(atensor) >= Int<3>{}); // VMK + CUTE_STATIC_ASSERT_V(size<0>(atensor) == size<1>(LayoutA_TV{})); + + if constexpr (has_dereference::value) { + // If the intended FrgTypeA is a view (of the current tensor), forward the whole + static_assert(is_same::value_type>::value + , "Expecting ValTypeA type"); + return make_tensor(static_cast(atensor)); + } else { + // Else, the intended FrgTypeA is a value type, construct a new tensor with a fragment layout + return make_fragment_like(atensor); + } + + CUTE_GCC_UNREACHABLE; + } + + template + CUTE_HOST_DEVICE static constexpr + auto + make_fragment_B(BTensor&& btensor) + { + // Check that this tensor is likely already partitioned + CUTE_STATIC_ASSERT_V(rank(btensor) >= Int<3>{}); // VNK + CUTE_STATIC_ASSERT_V(size<0>(btensor) == size<1>(LayoutB_TV{})); + + if constexpr (has_dereference::value) { + // If the intended FrgTypeB is a view (of the current tensor), forward the whole + static_assert(is_same::value_type>::value + , "Expecting ValTypeB type"); + return make_tensor(static_cast(btensor)); + } else { + // Else, the intended FrgTypeB is a value type, construct a new tensor with a fragment layout + return make_fragment_like(btensor); + } + + CUTE_GCC_UNREACHABLE; + } +}; + +// +// A tiling of mma atoms +// + +template +struct ThrMMA; + +// @tparam MMA_Atom The MMA_Atom to use in the TiledMMA +// @tparam AtomLayoutMNK The MNK-tiling of the Atom to be performed. +// @tparam PermuationsMNK Permutations to apply to each MNK-mode before tiling for the Atom. +template > +struct TiledMMA : MMA_Atom +{ + using Atom = MMA_Atom; + using AtomShape_MNK = typename MMA_Atom::Shape_MNK; + using AtomThrID = typename MMA_Atom::ThrID; + using AtomLayoutC_TV = typename MMA_Atom::LayoutC_TV; + using AtomLayoutA_TV = typename MMA_Atom::LayoutA_TV; + using AtomLayoutB_TV = typename MMA_Atom::LayoutB_TV; + + static_assert( rank_v == 3, "TiledMMA requires rank-3 AtomLayoutMNK"); + static_assert( rank_v == 3, "TiledMMA requires rank-3 PermutationMNK"); + static_assert( is_tuple::value, "TiledMMA requires independent permutations of MNK."); + static_assert(is_static::value, "TiledMMA requires static permutations of MNK."); + + using ThrLayoutVMNK = decltype(tiled_product(AtomThrID{}, AtomLayoutMNK{})); + ThrLayoutVMNK thr_layout_vmnk_; + + CUTE_HOST_DEVICE constexpr + TiledMMA(MMA_Atom const& mma_atom = {}, AtomLayoutMNK const& thr_layout_mnk = {}) + : MMA_Atom(mma_atom), + thr_layout_vmnk_(tiled_product(AtomThrID{}, thr_layout_mnk)) {} + + CUTE_HOST_DEVICE constexpr auto + get_thr_layout_vmnk() const { + return thr_layout_vmnk_; + } + + // Tile a tensor or a layout from shape + // (M,N,...) 
+ // to shape + // ((ThrV,(ThrM,ThrN)),(FrgV,(RestM,RestN,...))) + // where + // ThrV: The threads local to an MMA. layout<0>(ThrLayoutVMNK): ThrV -> thread_idx + // ThrM: The threads tiled in M. layout<1>(ThrLayoutVMNK): ThrM -> thread_idx + // ThrN: The threads tiled in N. layout<2>(ThrLayoutVMNK): ThrN -> thread_idx + // FrgV: The values local to an MMA. + // RestM: The values tiled in M. + // RestN: The values tiled in N. + template + CUTE_HOST_DEVICE constexpr + auto + thrfrg_C(CTensor&& ctensor) const + { + CUTE_STATIC_ASSERT_V(rank(ctensor) >= Int<2>{}); + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(permutation_mnk<0>(), + permutation_mnk<1>()); + auto t_tensor = logical_divide(ctensor, t_tile); // (PermM,PermN) + + // Tile the tensor for the Atom + auto c_tile = make_tile(make_layout(size<0>(AtomShape_MNK{})), + make_layout(size<1>(AtomShape_MNK{}))); + auto c_tensor = zipped_divide(t_tensor, c_tile); // ((AtomM,AtomN),(RestM,RestN)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + auto tv_tensor = c_tensor.compose(AtomLayoutC_TV{},_); // ((ThrV,FrgV),(RestM,RestN)) + + // Tile the tensor for the C-threads + auto thr_tile = make_tile(_, + make_tile(make_layout(size<1>(thr_layout_vmnk_)), + make_layout(size<2>(thr_layout_vmnk_)))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrN)),(FrgV,(RestM,RestN))) + + return thr_tensor; + } + + // Tile a tensor or a layout from shape + // (M,K,...) + // to shape + // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK,...))) + // where + // ThrV: The threads local to an MMA. layout<0>(ThrLayoutVMNK): ThrV -> thread_idx + // ThrM: The threads tiled in M. layout<1>(ThrLayoutVMNK): ThrM -> thread_idx + // ThrK: The threads tiled in K. layout<3>(ThrLayoutVMNK): ThrK -> thread_idx + // FrgV: The values local to an MMA. + // RestM: The values tiled in M. + // RestK: The values tiled in K. + template + CUTE_HOST_DEVICE constexpr + auto + thrfrg_A(ATensor&& atensor) const + { + CUTE_STATIC_ASSERT_V(rank(atensor) >= Int<2>{}); + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(permutation_mnk<0>(), + permutation_mnk<2>()); + auto t_tensor = logical_divide(atensor, t_tile); // (PermM,PermK) + + // Tile the tensor for the Atom + auto a_tile = make_tile(make_layout(size<0>(AtomShape_MNK{})), + make_layout(size<2>(AtomShape_MNK{}))); + auto a_tensor = zipped_divide(t_tensor, a_tile); // ((AtomM,AtomK),(RestM,RestK)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + auto tv_tensor = a_tensor.compose(AtomLayoutA_TV{},_); // ((ThrV,FrgV),(RestM,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = make_tile(_, + make_tile(make_layout(size<1>(thr_layout_vmnk_)), + make_layout(size<3>(thr_layout_vmnk_)))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrM,ThrK)),(FrgV,(RestM,RestK))) + + return thr_tensor; + } + + // Tile a tensor or a layout from shape + // (N,K,...) + // to shape + // ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK,...))) + // where + // ThrV: The threads local to an MMA. layout<0>(ThrLayoutVMNK): ThrV -> thread_idx + // ThrN: The threads tiled in N. layout<2>(ThrLayoutVMNK): ThrN -> thread_idx + // ThrK: The threads tiled in K. layout<3>(ThrLayoutVMNK): ThrK -> thread_idx + // FrgV: The values local to an MMA. + // RestN: The values tiled in N. + // RestK: The values tiled in K. 
+ template + CUTE_HOST_DEVICE constexpr + auto + thrfrg_B(BTensor&& btensor) const + { + CUTE_STATIC_ASSERT_V(rank(btensor) >= Int<2>{}); + // Reorder the tensor for the TiledAtom + auto t_tile = make_tile(permutation_mnk<1>(), + permutation_mnk<2>()); + auto t_tensor = logical_divide(btensor, t_tile); // (PermN,PermK) + + // Tile the tensor for the Atom + auto b_tile = make_tile(make_layout(size<1>(AtomShape_MNK{})), + make_layout(size<2>(AtomShape_MNK{}))); + auto b_tensor = zipped_divide(t_tensor, b_tile); // ((AtomN,AtomK),(RestN,RestK)) + + // Transform the Atom mode from (M,K) to (Thr,Val) + auto tv_tensor = b_tensor.compose(AtomLayoutB_TV{},_); // ((ThrV,FrgV),(RestN,RestK)) + + // Tile the tensor for the Thread + auto thr_tile = make_tile(_, + make_tile(make_layout(size<2>(thr_layout_vmnk_)), + make_layout(size<3>(thr_layout_vmnk_)))); + auto thr_tensor = zipped_divide(tv_tensor, thr_tile); // ((ThrV,(ThrN,ThrK)),(FrgV,(RestN,RestK))) + + return thr_tensor; + } + + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_slice(ThrIdx const& thr_idx) const + { + auto thr_vmnk = thr_layout_vmnk_.get_flat_coord(thr_idx); + return ThrMMA{*this, thr_vmnk}; + } + + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_thread_slice(ThrIdx const& thr_idx) const + { + return get_slice(thr_idx); + } + + // + // Utility for printing and visualization + // + + // The permutation applied to the MNK-mode data + template + CUTE_HOST_DEVICE constexpr + auto + permutation_mnk() const { + static_assert(0 <= I && I < 3); + auto perm = get(PermutationMNK{}); + return conditional_return(is_underscore{}, size(AtomShape_MNK{}) * size(get_thr_layout_vmnk()), perm); + } + + // The size of the MNK-mode + template + CUTE_HOST_DEVICE constexpr + auto + tile_size_mnk() const { + static_assert(0 <= I && I < 3); + return size(permutation_mnk()); + } + + CUTE_HOST_DEVICE constexpr + auto + get_layoutC_MN() const + { + // (M,N) -> (M,N) + auto ref_C = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<1>())); + // (cthrid,val) -> (M,N) + auto layoutC_TV = thrfrg_C(ref_C); + // (M,N) -> (cthrid,frg) + auto layoutC_MN = right_inverse(layoutC_TV).with_shape(shape(ref_C)); + + // cthrid = (v,m,n) -> thr_idx + auto thrID_C = thr_layout_vmnk_(_,_,_,Int<0>{}); + + return cute::make_tuple(layoutC_MN, thrID_C); + } + + CUTE_HOST_DEVICE constexpr + auto + get_layoutC_TV() const + { + // (M,N) -> (M,N) + auto ref_C = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<1>())); + // (cthrid,val) -> (M,N) + auto layoutC_TV = thrfrg_C(ref_C); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(thr_layout_vmnk_); + + // (thr_idx,val) -> (M,N) + return layoutC_TV.compose(thridx_2_thrid, _); + } + + CUTE_HOST_DEVICE constexpr + auto + get_layoutA_MK() const + { + // (M,K) -> (M,K) + auto ref_A = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<2>())); + // (athrid,val) -> (M,K) + auto layoutA_TV = thrfrg_A(ref_A); + // (M,K) -> (athrid,frg) + auto layoutA_MK = right_inverse(layoutA_TV).with_shape(shape(ref_A)); + + // athrid = (v,m,k) -> thr_idx + auto thrID_A = thr_layout_vmnk_(_,_,Int<0>{},_); + + return cute::make_tuple(layoutA_MK, thrID_A); + } + + CUTE_HOST_DEVICE constexpr + auto + get_layoutA_TV() const + { + // (M,K) -> (M,K) + auto ref_A = make_layout(make_shape(tile_size_mnk<0>(), tile_size_mnk<2>())); + // (athrid,val) -> (M,K) + auto layoutA_TV = thrfrg_A(ref_A); + + // (ThrV,(ThrM,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto atile = make_tile(_, + 
make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk_), size<2>(thr_layout_vmnk_)), + make_stride( Int<1>{} , Int<0>{} )), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(thr_layout_vmnk_); + + // (thr_idx,val) -> (M,K) + return thrfrg_A(ref_A).compose(atile, _).compose(thridx_2_thrid, _); + } + + CUTE_HOST_DEVICE constexpr + auto + get_layoutB_NK() const + { + // (N,K) -> (N,K) + auto ref_B = make_layout(make_shape(tile_size_mnk<1>(), tile_size_mnk<2>())); + // (bthrid,val) -> (N,K) + auto layoutB_TV = thrfrg_B(ref_B); + // (N,K) -> (bthrid,frg) + auto layoutB_NK = right_inverse(layoutB_TV).with_shape(shape(ref_B)); + + // bthrid = (v,n,k) -> thr_idx + auto thrID_B = thr_layout_vmnk_(_,Int<0>{},_,_); + + return cute::make_tuple(layoutB_NK, thrID_B); + } + + CUTE_HOST_DEVICE constexpr + auto + get_layoutB_TV() const + { + // (N,K) -> (N,K) + auto ref_B = make_layout(make_shape(tile_size_mnk<1>(), tile_size_mnk<2>())); + // (bthrid,val) -> (N,K) + auto layoutB_TV = thrfrg_B(ref_B); + + // (ThrV,(ThrN,ThrK)) -> (ThrV,(ThrM,ThrN,ThrK)) + auto btile = make_tile(_, + make_tile(make_layout(make_shape (size<1>(thr_layout_vmnk_), size<2>(thr_layout_vmnk_)), + make_stride( Int<0>{} , Int<1>{} )), + _)); + + // thr_idx -> (ThrV,ThrM,ThrN,ThrK) + auto thridx_2_thrid = right_inverse(thr_layout_vmnk_); + + // (thr_idx,val) -> (N,K) + return thrfrg_B(ref_B).compose(btile, _).compose(thridx_2_thrid, _); + } +}; + +template +struct ThrMMA : TiledMMA +{ + ThrVMNK thr_vmnk_; + + template + CUTE_HOST_DEVICE constexpr + auto + partition_C(CTensor&& ctensor) const + { + auto thr_tensor = make_tensor(static_cast(ctensor).data(), this->thrfrg_C(ctensor.layout())); + + auto thr_vmn = make_coord(get<0>(thr_vmnk_), make_coord(get<1>(thr_vmnk_), get<2>(thr_vmnk_))); + return thr_tensor(thr_vmn, make_coord(_, repeat(thr_tensor)>(_))); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_A(ATensor&& atensor) const + { + auto thr_tensor = make_tensor(static_cast(atensor).data(), this->thrfrg_A(atensor.layout())); + + auto thr_vmk = make_coord(get<0>(thr_vmnk_), make_coord(get<1>(thr_vmnk_), get<3>(thr_vmnk_))); + return thr_tensor(thr_vmk, make_coord(_, repeat(thr_tensor)>(_))); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_B(BTensor&& btensor) const + { + auto thr_tensor = make_tensor(static_cast(btensor).data(), this->thrfrg_B(btensor.layout())); + + auto thr_vnk = make_coord(get<0>(thr_vmnk_), make_coord(get<2>(thr_vmnk_), get<3>(thr_vmnk_))); + return thr_tensor(thr_vnk, make_coord(_, repeat(thr_tensor)>(_))); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_fragment_C(CTensor&& ctensor) const + { + return TiledMMA::make_fragment_C(partition_C(ctensor)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_fragment_A(ATensor&& atensor) const + { + return TiledMMA::make_fragment_A(partition_A(atensor)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + partition_fragment_B(BTensor&& btensor) const + { + return TiledMMA::make_fragment_B(partition_B(btensor)); + } +}; + +// +// These tile the MMA_Atom as a whole +// + +template >, + class Permutations = Tile> +CUTE_HOST_DEVICE constexpr +auto +make_tiled_mma(MMA_Atom const& mma_atom, + MMAThrLayout const& thr_layout = {}, + Permutations const& permutations = {}) +{ + auto thr_layout_mnk = append<3>(thr_layout, Layout<_1,_0>{}); + auto permutation_mnk = append<3>(permutations, _); + + return TiledMMA, + decltype(thr_layout_mnk), + 
decltype(permutation_mnk)>{mma_atom, thr_layout_mnk}; +} + +template >, + class Permutations = Tile> +CUTE_HOST_DEVICE constexpr +auto +make_tiled_mma(MMA_Op const&, + MMAThrLayout const& thr_layout = {}, + Permutations const& permutations = {}) +{ + // Attempt to wrap in an MMA_Atom<> and forward + return make_tiled_mma(MMA_Atom{}, thr_layout, permutations); +} + +// +// partition_fragment_C -- static context +// + +template +CUTE_HOST_DEVICE constexpr +auto +partition_shape_C(TiledMMA const& mma, Shape_MN const& shape_MN) +{ + auto dummy = make_layout(shape(shape_MN)); + auto dummy_tv = mma.thrfrg_C(dummy); + // Slice+rearrange like partition_C + auto dummy_v = dummy_tv(Int<0>{}, make_coord(_, repeat(_))); + return shape(dummy_v); + +} + + +template +CUTE_HOST_DEVICE constexpr +auto +partition_fragment_C(TiledMMA const& mma, Shape_MN const& shapeMN) +{ + return make_tensor::FrgTypeC>(partition_shape_C(mma, shapeMN)); +} + +// partition_fragment_A and partition_fragment_B often depend on the +// layout of A and B and/or the thread_idx that is requesting the partition. +// For these reasons, they should not be used in a static context. +// See TiledMMA::get_slice(thr_idx).partition_fragment_A(tensorA) instead. + +template +CUTE_HOST_DEVICE constexpr +auto +partition_shape_A(TiledMMA const& mma, Shape_MK const& shape_MK) +{ + auto dummy = make_layout(shape(shape_MK)); + auto dummy_tv = mma.thrfrg_A(dummy); + // Slice+rearrange like partition_A + auto dummy_v = dummy_tv(Int<0>{}, make_coord(_, repeat(_))); + return shape(dummy_v); + +} + +template +CUTE_HOST_DEVICE constexpr +auto +partition_shape_B(TiledMMA const& mma, Shape_NK const& shape_NK) +{ + auto dummy = make_layout(shape(shape_NK)); + auto dummy_tv = mma.thrfrg_B(dummy); + // Slice+rearrange like partition_B + auto dummy_v = dummy_tv(Int<0>{}, make_coord(_, repeat(_))); + return shape(dummy_v); + +} + +// +// Size +// + +template +CUTE_HOST_DEVICE constexpr +auto +tile_size(TiledMMA const& mma) +{ + return mma.template tile_size_mnk(); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tile_shape(TiledMMA const& mma) +{ + return make_shape(tile_size<0>(mma), tile_size<1>(mma), tile_size<2>(mma)); +} + +// Deprecate? 
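+// Note: size(tiled_mma) returns the number of participating threads (identical to thr_size),
+// not an MNK extent; use tile_size<I> or tile_shape for the tile extents.
+// Illustrative sketch (the atom and thread layout are assumptions):
+//   auto mma = make_tiled_mma(SM80_16x8x16_F32F16F16F32_TN{}, Layout<Shape<_2,_2,_1>>{});
+//   size(mma);       // 128 threads
+//   tile_shape(mma); // (_32,_16,_16)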
+template +CUTE_HOST_DEVICE constexpr +auto +size(TiledMMA const& mma) +{ + return size(mma.get_thr_layout_vmnk()); +} + +// Alias +template +CUTE_HOST_DEVICE constexpr +auto +thr_size(TiledMMA const& mma) +{ + return size(mma.get_thr_layout_vmnk()); +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE +void +print(MMA_Atom> const&) +{ + using Atom = MMA_Atom>; + print("MMA_Atom\n"); + print(" ThrID: "); print(typename Atom::ThrID{}); print("\n"); + print(" Shape_MNK: "); print(typename Atom::Shape_MNK{}); print("\n"); + print(" LayoutA_TV: "); print(typename Atom::LayoutA_TV{}); print("\n"); + print(" LayoutB_TV: "); print(typename Atom::LayoutB_TV{}); print("\n"); + print(" LayoutC_TV: "); print(typename Atom::LayoutC_TV{}); print("\n"); +} + +template +CUTE_HOST_DEVICE +void +print(TiledMMA const& mma) +{ + print("TiledMMA\n"); + print(" ThrLayoutVMNK: "); print(mma.get_thr_layout_vmnk()); print("\n"); + print(" PermutationMNK: "); print(TiledPerm{}); print("\n"); + print(static_cast(mma)); +} + +template +CUTE_HOST_DEVICE +void +print(ThrMMA const& thr_mma) +{ + print("ThrMMA\n"); + print(" Thr VMNK: "); print(thr_mma.thr_vmnk_); print("\n"); + print(static_cast(thr_mma)); +} + +// MMA Atom to LaTeX TikZ +template +CUTE_HOST_DEVICE +void +print_latex(MMA_Atom const& mma_atom, + TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string +{ + print_latex(make_tiled_mma(mma_atom)); +} + +// TiledMMA to LaTeX TikZ +template +CUTE_HOST_DEVICE +void +print_latex(TiledMMA const& mma, + TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string +{ + auto layout_and_thrid_C = mma.get_layoutC_MN(); + auto layoutC_MN = get<0>(layout_and_thrid_C); + auto thrID_C = get<1>(layout_and_thrid_C); + + auto layout_and_thrid_A = mma.get_layoutA_MK(); + auto layoutA_MK = get<0>(layout_and_thrid_A); + auto thrID_A = get<1>(layout_and_thrid_A); + + auto layout_and_thrid_B = mma.get_layoutB_NK(); + auto layoutB_NK = get<0>(layout_and_thrid_B); + auto thrID_B = get<1>(layout_and_thrid_B); + + print_latex_mma(layoutC_MN, thrID_C, + layoutA_MK, thrID_A, + layoutB_NK, thrID_B); +} + +// MNK MMA Layout to LaTeX TikZ +template +CUTE_HOST_DEVICE +void +print_latex_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx + LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx + LayoutB const& B, ThrIDB const& TB, // (n,k) -> (tid,vid) and tid -> thr_idx + TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string +{ + CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{}); + + assert(size<0>(A) == size<0>(C)); + assert(size<0>(B) == size<1>(C)); + assert(size<1>(A) == size<1>(B)); + + // Commented prints + printf("%% LayoutC: "); print(C); printf("\n"); + printf("%% ThrIDC : "); print(TC); printf("\n"); + printf("%% LayoutA: "); print(A); printf("\n"); + printf("%% ThrIDA : "); print(TA); printf("\n"); + printf("%% LayoutB: "); print(B); printf("\n"); + printf("%% ThrIDB : "); print(TB); printf("\n\n"); + // Header + printf("\\documentclass[convert]{standalone}\n" + "\\usepackage{tikz}\n\n" + "\\begin{document}\n" + "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); + + // C starting at 0,0 + for (int m = 0; m < size<0>(C); ++m) { + for (int n = 0; n < size<1>(C); ++n) { + int thrid = C(m,n) % size(TC); + int val_idx = C(m,n) / size(TC); + int thr_idx = TC(thrid); + + 
printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color(thr_idx, val_idx), + m, n, + thr_idx, val_idx); + } + } + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", + 0, 0, int(size<0>(C)), int(size<1>(C))); + + // A starting at 0,-size<1>(A)-1 + for (int m = 0; m < size<0>(A); ++m) { + for (int k = 0; k < size<1>(A); ++k) { + int thrid = A(m,k) % size(TA); + int val_idx = A(m,k) / size(TA); + int thr_idx = TA(thrid); + + printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color(thr_idx, val_idx), + m, k-1-size<1>(A), + thr_idx, val_idx); + } + } + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", + 0, int(-size<1>(A)-1), int(size<0>(A)), -1); + // A labels + for (int m = 0, k = -1; m < size<0>(A); ++m) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k-1-size<1>(A), m); + } + for (int m = -1, k = 0; k < size<1>(A); ++k) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", m, k-1-size<1>(A), k); + } + + // B starting at -size<1>(B)-1,0 + for (int n = 0; n < size<0>(B); ++n) { + for (int k = 0; k < size<1>(B); ++k) { + int thrid = B(n,k) % size(TB); + int val_idx = B(n,k) / size(TB); + int thr_idx = TB(thrid); + + printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color(thr_idx, val_idx), + k-1-size<1>(B), n, + thr_idx, val_idx); + } + } + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (%d,%d) grid (%d,%d);\n\n", + int(-size<1>(B)-1), 0, -1, int(size<0>(B))); + // B labels + for (int n = 0, k = -1; n < size<0>(B); ++n) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k-1-size<1>(B), n, n); + } + for (int n = -1, k = 0; k < size<1>(B); ++k) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", k-1-size<1>(B), n, k); + } + + // Footer + printf("\\end{tikzpicture}\n" + "\\end{document}\n"); +} + +// MNK MMA Layout to console printer +template +CUTE_HOST_DEVICE +void +print_layout_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx + LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx + LayoutB const& B, ThrIDB const& TB) // (n,k) -> (tid,vid) and tid -> thr_idx +{ + CUTE_STATIC_ASSERT_V(rank(C) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(A) == Int<2>{}); + CUTE_STATIC_ASSERT_V(rank(B) == Int<2>{}); + + assert(size<0>(A) == size<0>(C)); + assert(size<0>(B) == size<1>(C)); + assert(size<1>(A) == size<1>(B)); + + int a_width = size<1>(A) * 6 + 4; + + // Print out B (white-shifted) k-by-n + for (int k = 0; k < size<1>(B); ++k) { + // Header + printf("%*s", a_width, ""); + for (int n = 0; n < size<0>(B); ++n) printf("+-----"); + printf("+\n"); + // Values + printf("%*s", a_width, ""); + for (int n = 0; n < size<0>(B); ++n) printf("|T%02dV%1d", int(TB(B(n,k) % size(TB))), int(B(n,k) / size(TB))); + printf("|\n"); + } + // Footer + printf("%*s", a_width, ""); + for (int n = 0; n < size<0>(B); ++n) printf("+-----"); + printf("+\n\n"); + + // Print out A m-by-k and C m-by-n + for (int m = 0; m < size<0>(A); ++m) { + // Header + for (int k = 0; k < size<1>(A); ++k) printf("+-----"); + printf("+ "); + for (int n = 0; n < size<1>(C); ++n) printf("+-----"); + printf("+\n"); + // Values + for (int k = 0; k < size<1>(A); ++k) printf("|T%02dV%1d", int(TA(A(m,k) % size(TA))), int(A(m,k) / size(TA))); + printf("| "); + for (int n = 0; n < size<1>(C); ++n) printf("|T%02dV%1d", int(TC(C(m,n) % size(TC))), int(C(m,n) / size(TC))); + printf("|\n"); + } + // Footer + for 
(int k = 0; k < size<1>(A); ++k) printf("+-----"); + printf("+ "); + for (int n = 0; n < size<1>(C); ++n) printf("+-----"); + printf("+\n"); +} + +// MNK MMA Layout to SVG -- 8-value color coded by thread +template +CUTE_HOST_DEVICE +void +print_svg_mma(LayoutC const& C, ThrIDC const& TC, // (m,n) -> (tid,vid) and tid -> thr_idx + LayoutA const& A, ThrIDA const& TA, // (m,k) -> (tid,vid) and tid -> thr_idx + LayoutB const& B, ThrIDB const& TB) // (n,k) -> (tid,vid) and tid -> thr_idx +{ + char const *color_map[8] = {"175,175,255", "175,255,175", "255,255,175", + "255,175,175", "210,210,255", "210,255,210", + "255,255,210", "255,210,210"}; + + const int cell_width = 20; + const int cell_height = 20; + + const int page_width = (size<1>(A) + size<0>(B) + 2) * cell_width; + const int page_height = (size<1>(B) + size<0>(A) + 2) * cell_height; + + // header + printf("\n", + page_width, page_height); + + // C + int c_base_x = (size<1>(A) + 2) * cell_width; + int c_base_y = (size<1>(B) + 2) * cell_height; + for (int m = 0; m < cute::size<0>(C); ++m) { + for (int n = 0; n < cute::size<1>(C); ++n) { + + int thrid = C(m, n) % size(TC); + int val_idx = C(m, n) / size(TC); + int thr_idx = TC(thrid); + + int x = n * cell_width + c_base_x; + int y = m * cell_height + c_base_y; + + int thr_x = x + cell_width / 2; + int thr_y = y + cell_height / 4; + int val_x = x + cell_width / 2; + int val_y = y + cell_height * 3 / 4; + + printf("\n", + x, y, cell_width, cell_height, color_map[thr_idx % 8]); + + printf("T%d\n", + thr_x, thr_y, thr_idx); + printf("V%d\n", + val_x, val_y, val_idx); + } + } + + // A + int a_base_x = cell_width; + int a_base_y = (size<1>(B) + 2) * cell_height; + for (int m = 0; m < size<0>(A); ++m) { + for (int k = 0; k < size<1>(A); ++k) { + int thrid = A(m, k) % size(TA); + int val_idx = A(m, k) / size(TA); + int thr_idx = TA(thrid); + + int x = k * cell_width + a_base_x; + int y = m * cell_height + a_base_y; + + int thr_x = x + cell_width / 2; + int thr_y = y + cell_height / 4; + int val_x = x + cell_width / 2; + int val_y = y + cell_height * 3 / 4; + + printf("\n", + x, y, cell_width, cell_height, color_map[thr_idx % 8]); + printf("T%d\n", + thr_x, thr_y, thr_idx); + printf("V%d\n", + val_x, val_y, val_idx); + } + } + + // B + int b_base_x = (size<1>(A) + 2) * cell_width; + int b_base_y = cell_height; + for (int n = 0; n < size<0>(B); ++n) { + for (int k = 0; k < size<1>(B); ++k) { + int thrid = B(n, k) % size(TB); + int val_idx = B(n, k) / size(TB); + int thr_idx = TB(thrid); + + int x = n * cell_width + b_base_x; + int y = k * cell_height + b_base_y; + + int thr_x = x + cell_width / 2; + int thr_y = y + cell_height / 4; + int val_x = x + cell_width / 2; + int val_y = y + cell_height * 3 / 4; + + printf("\n", + x, y, cell_width, cell_height, color_map[thr_idx % 8]); + printf("T%d\n", + thr_x, thr_y, thr_idx); + printf("V%d\n", + val_x, val_y, val_idx); + } + } + + // A labels + for (int m = 0; m < size<0>(A); ++m) { + int x = cell_width / 2; + int y = m * cell_height + cell_height / 2 + a_base_y; + printf("%d\n", + x, y, m); + } + for (int k = 0; k < size<1>(A); ++k) { + int x = cell_width + k * cell_width + cell_width / 2; + int y = -cell_height / 2 + a_base_y; + printf("%d\n", + x, y, k); + } + + // B labels + for (int n = 0; n < size<0>(B); ++n) { + int x = b_base_x + cell_width * n + cell_width / 2; + int y = cell_height / 2; + printf("%d\n", + x, y, n); + } + for (int k = 0; k < size<1>(B); ++k) { + int x = b_base_x - cell_width / 2; + int y = cell_height * (k + 1) + cell_height / 
2; + printf("%d\n", + x, y, k); + } + + // footer + printf(""); +} + +template +CUTE_HOST_DEVICE +void +print_svg(MMA_Atom const &mma_atom) { + print_svg(make_tiled_mma(mma_atom)); +} + +template +CUTE_HOST_DEVICE +void +print_svg(TiledMMA const &mma) { + auto layout_and_thrid_C = mma.get_layoutC_MN(); + auto layoutC_MN = get<0>(layout_and_thrid_C); + auto thrID_C = get<1>(layout_and_thrid_C); + + auto layout_and_thrid_A = mma.get_layoutA_MK(); + auto layoutA_MK = get<0>(layout_and_thrid_A); + auto thrID_A = get<1>(layout_and_thrid_A); + + auto layout_and_thrid_B = mma.get_layoutB_NK(); + auto layoutB_NK = get<0>(layout_and_thrid_B); + auto thrID_B = get<1>(layout_and_thrid_B); + + print_svg_mma(layoutC_MN, thrID_C, layoutA_MK, thrID_A, layoutB_NK, thrID_B); +} + +} // namespace cute + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#include +#include +#include +#include +#include +#include +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cute/atom/mma_traits.hpp b/include/cute/atom/mma_traits.hpp new file mode 100644 index 0000000000..0994698a87 --- /dev/null +++ b/include/cute/atom/mma_traits.hpp @@ -0,0 +1,189 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // cute::Tensor +#include // cute::is_rmem +#include // cute::UniversalFMA +#include // cute::detail::explode + +namespace cute +{ + +/** + * concept MMA_Traits + * { + * using ValTypeD = // Logical A-value type + * using ValTypeA = // Logical B-value type + * using ValTypeB = // Logical C-value type + * using ValTypeC = // Logical D-value type (NOTE: Not used? 
Assumed == ValTypeD) + * + * using FrgTypeA = // A-type consumed by MMA (if ommitted, same as ValTypeA) + * using FrgTypeB = // B_type consumed by MMA (if ommitted, same as ValTypeB) + * using FrgTypeC = // C_type consumed by MMA (if ommitted, same as ValTypeC) + * + * using Shape_MNK = // Logical MxNxK shape of the MMA + * + * using ThrID = // Logical thread id (tid) -> tidx + * + * using ALayout = // (Logical thread id (tid), Logical value id (vid)) -> Flat MK-coord + * using BLayout = // (Logical thread id (tid), Logical value id (vid)) -> Flat NK-coord + * using CLayout = // (Logical thread id (tid), Logical value id (vid)) -> Flat MN-coord + * }; + */ + +template +struct MMA_Traits +{ + static_assert(sizeof(MMAOperation) == 0, "MMA_Traits not implemented for this MMA_Operation."); +}; + +template +struct MMA_Traits> +{ + using ValTypeD = D; + using ValTypeA = A; + using ValTypeB = B; + using ValTypeC = C; + + // Logical shape of the MMA + using Shape_MNK = Shape<_1,_1,_1>; + + // Logical thread id (tid) -> tidx + using ThrID = Layout<_1>; + + // (Logical thread id (tid), Logical value id (vid)) -> coord + + // (tid,vid) -> (m,k) + using ALayout = Layout>; + // (tid,vid) -> (n,k) + using BLayout = Layout>; + // (tid,vid) -> (m,n) + using CLayout = Layout>; +}; + +// Extract an MMA_Op from an MMA_Traits +template +struct MMA_Op {}; + +template +struct MMA_Op> { + using type = MMA_Op_Arg; +}; + +// +// Generic mma_unpack for any MMA_Traits +// + +template +CUTE_HOST_DEVICE constexpr +void +mma_unpack(AnyMMATraits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + + // Register value types from the MMA_Operation register arrays + using MMA_Op = typename MMA_Op::type; + using RegTypeD = typename remove_extent::type; + using RegTypeA = typename remove_extent::type; + using RegTypeB = typename remove_extent::type; + using RegTypeC = typename remove_extent::type; + + Tensor rA = recast(A); + Tensor rB = recast(B); + Tensor rD = recast(D); + Tensor rC = recast(C); + + constexpr int RegNumD = extent::value; + constexpr int RegNumA = extent::value; + constexpr int RegNumB = extent::value; + constexpr int RegNumC = extent::value; + + CUTE_STATIC_ASSERT_V(size(rA) == Int{}); + CUTE_STATIC_ASSERT_V(size(rB) == Int{}); + CUTE_STATIC_ASSERT_V(size(rD) == Int{}); + CUTE_STATIC_ASSERT_V(size(rC) == Int{}); + + detail::explode(MMA_Op::fma, + rD, make_int_sequence{}, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}); +} + +// Accept mutable temporaries +template +CUTE_HOST_DEVICE constexpr +void +mma_unpack(AnyMMATraits const& traits, + Tensor && D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + mma_unpack(traits, D, A, B, C); +} + +namespace detail { + +template +struct FrgTypeA_or_Default { using type = typename X::ValTypeA; }; +template +struct FrgTypeA_or_Default> { using type = typename X::FrgTypeA; }; + +template +struct FrgTypeB_or_Default { using type = typename X::ValTypeB; }; +template +struct FrgTypeB_or_Default> { using type = typename X::FrgTypeB; }; + +template +struct FrgTypeC_or_Default { using type = typename X::ValTypeC; }; +template +struct FrgTypeC_or_Default> { using type = typename X::FrgTypeC; }; + +} // end 
namespace detail + +} // namespace cute diff --git a/include/cute/atom/mma_traits_sm61.hpp b/include/cute/atom/mma_traits_sm61.hpp new file mode 100644 index 0000000000..f72a639400 --- /dev/null +++ b/include/cute/atom/mma_traits_sm61.hpp @@ -0,0 +1,73 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_1,_1,_4>; + using ThrID = Layout<_1>; + using ALayout = Layout>; + using BLayout = Layout>; + using CLayout = Layout>; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int16_t; + using ValTypeB = int16_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_1,_1,_2>; + using ThrID = Layout<_1>; + using ALayout = Layout>; + using BLayout = Layout>; + using CLayout = Layout>; +}; + +} // namespace cute diff --git a/include/cute/atom/mma_traits_sm70.hpp b/include/cute/atom/mma_traits_sm70.hpp new file mode 100644 index 0000000000..f0702a9617 --- /dev/null +++ b/include/cute/atom/mma_traits_sm70.hpp @@ -0,0 +1,198 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +namespace { + +// Logical thread id to thread idx (quadpair) +using SM70_QuadPair = Layout, + Stride<_1,_16>>; +// (T8,V4) -> (M8,K4) +using SM70_8x4_Row = Layout, + Stride<_1,_8>>; +// (T8,V4) -> (M8,K4) +using SM70_8x4_Col = Layout,_4>, + Stride,_1>>; +// (T8,V8) -> (M8,N8) +using SM70_8x8_16b = Layout, + Stride<_1,_8>>; +// (T8,V8) -> (M8,N8) +using SM70_8x8_32b = Layout,Shape <_2,_2, _2>>, + Stride,Stride<_8,_2,_32>>>; + +} + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Row; + using BLayout = SM70_8x4_Row; + using CLayout = SM70_8x8_16b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Col; + using BLayout = SM70_8x4_Col; + using CLayout = SM70_8x8_16b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Col; + using BLayout = SM70_8x4_Row; + using CLayout = SM70_8x8_16b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Row; + using BLayout = SM70_8x4_Col; + using CLayout = 
SM70_8x8_16b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Row; + using BLayout = SM70_8x4_Row; + using CLayout = SM70_8x8_32b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Col; + using BLayout = SM70_8x4_Col; + using CLayout = SM70_8x8_32b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Col; + using BLayout = SM70_8x4_Row; + using CLayout = SM70_8x8_32b; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = SM70_QuadPair; + using ALayout = SM70_8x4_Row; + using BLayout = SM70_8x4_Col; + using CLayout = SM70_8x8_32b; +}; + +/////////////////////////////////////////////////////////////////////////////// +} // namespace cute diff --git a/include/cute/atom/mma_traits_sm75.hpp b/include/cute/atom/mma_traits_sm75.hpp new file mode 100644 index 0000000000..1d3f51961c --- /dev/null +++ b/include/cute/atom/mma_traits_sm75.hpp @@ -0,0 +1,81 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_16,_8,_8>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; + using BLayout = Layout,_2>, + Stride,_8>>; + using CLayout = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_8,_8,_16>; + using ThrID = Layout<_32>; + using ALayout = Layout,_4>, + Stride,_8>>; + using BLayout = Layout,_4>, + Stride,_8>>; + using CLayout = Layout,_2>, + Stride,_8>>; +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cute diff --git a/include/cute/atom/mma_traits_sm80.hpp b/include/cute/atom/mma_traits_sm80.hpp new file mode 100644 index 0000000000..5f7e73e467 --- /dev/null +++ b/include/cute/atom/mma_traits_sm80.hpp @@ -0,0 +1,690 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include +#include + +namespace cute +{ + +namespace { + +// (T32,V1) -> (M8,N8) +using SM80_8x4 = Layout,_1>, + Stride,_0>>; +// (T32,V2) -> (M8,N8) +using SM80_8x8_Row = Layout,_2>, + Stride,_8>>; +// (T32,V4) -> (M8,N16) +using SM80_8x16_Row = Layout,_4>, + Stride,_8>>; +// (T32,V4) -> (M16,N8) +using SM80_16x8_Row = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; + +} + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////// fp16 = fp16 * fp16 + fp16 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using Shape_MNK = Shape<_16,_8,_8>; + using ThrID = Layout<_32>; + using ALayout = SM80_16x8_Row; + using BLayout = SM80_8x8_Row; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using Shape_MNK = Shape<_16,_8,_16>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape < _2,_2, _2>>, + Stride,Stride<_16,_8,_128>>>; + using BLayout = Layout,Shape <_2, _2>>, + Stride,Stride<_8,_64>>>; + using CLayout = SM80_16x8_Row; +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////// fp32 = fp16 * fp16 + fp32 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; +}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////// fp32 = bf16 * bf16 + fp32 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; +}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////// fp32 = tf32 * tf32 + fp32 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = cutlass::tfloat32_t; + using ValTypeB = 
cutlass::tfloat32_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_16,_8,_4>; + using ThrID = Layout<_32>; + using ALayout = Layout,_2>, + Stride,_8>>; + using BLayout = SM80_8x4; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits +{ + using ValTypeD = float; + using ValTypeA = cutlass::tfloat32_t; + using ValTypeB = cutlass::tfloat32_t; + using ValTypeC = float; + + using Shape_MNK = Shape<_16,_8,_8>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape <_2, _2>>, + Stride,Stride<_8,_64>>>; + using BLayout = Layout, _2>, + Stride,_32>>; + using CLayout = SM80_16x8_Row; +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////// fp64 = fp64 * fp64 + fp64 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = double; + using ValTypeA = double; + using ValTypeB = double; + using ValTypeC = double; + + using Shape_MNK = Shape<_8,_8,_4>; + using ThrID = Layout<_32>; + using ALayout = SM80_8x4; + using BLayout = SM80_8x4; + using CLayout = SM80_8x8_Row; +}; + +// Custom complex fp64 MMA composed of 4 fp64 MMAs -- same layouts +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = complex; + using ValTypeA = complex; + using ValTypeB = complex; + using ValTypeC = complex; +}; + +// Custom complex fp64 MMA composed of 3 fp64 MMAs -- same layouts +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = typename SM80_8x8x4_GC64C64C64GC64_TN::GaussComplex; + using ValTypeA = complex; + using ValTypeB = complex; + using ValTypeC = typename SM80_8x8x4_GC64C64C64GC64_TN::GaussComplex; +}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = s8 * s8 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_8,_8,_16>; + using ThrID = Layout<_32>; + using ALayout = SM80_8x16_Row; + using BLayout = SM80_8x16_Row; + using CLayout = SM80_8x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_16,_8,_16>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape < _4,_2>>, + Stride,Stride<_16,_8>>>; + using BLayout = SM80_8x16_Row; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_16,_8,_32>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape < _4,_2, _2>>, + Stride,Stride<_16,_8,_256>>>; + using BLayout = Layout, Shape <_4, _2>>, + Stride, Stride<_8,_128>>>; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = s8 * u8 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = 
int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = u8 * s8 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = u8 * u8 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = s4 * s4 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_8, _8, _32>; + using ThrID = Layout<_32>; + // (T32,V8) -> (M8,N32) + using ALayout = Layout, Shape <_8>>, + Stride, Stride<_8>>>; + using BLayout = Layout, Shape <_8>>, + Stride, Stride<_8>>>; + using CLayout = SM80_8x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_16, _8, _32>; + using ThrID = Layout<_32>; + // (T32,V16) -> (M16,N32) + using ALayout = Layout, Shape < _8, _2>>, + Stride, Stride<_16, _8>>>; + // (T32,V8) -> (M8,N32) + using BLayout = Layout, Shape <_8>>, + Stride, Stride<_8>>>; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct 
MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_16, _8, _64>; + using ThrID = Layout<_32>; + // (T32,V32) -> (M16,N64) + using ALayout = Layout, Shape < _8, _2, _2>>, + Stride, Stride<_16, _8, _512>>>; + // (T32,V16) -> (M8,N64) + using BLayout = Layout, Shape <_8, _2>>, + Stride, Stride<_8, _256>>>; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = s4 * u4 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = int4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = u4 * s4 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = int4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = u4 * u4 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template <> +struct MMA_Traits + : MMA_Traits { + using ValTypeD = int32_t; + using ValTypeA = uint4b_t; + using ValTypeB = uint4b_t; + using ValTypeC = int32_t; +}; + +template <> +struct MMA_Traits + : MMA_Traits {}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = b1 ^ b1 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits 
+{ + using ValTypeD = int32_t; + using ValTypeA = cute::uint1b_t; + using ValTypeB = cute::uint1b_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_16,_8,_256>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape<_32,_2,_2>>, + Stride,Stride<_16,_8,_2048>>>; + using BLayout = Layout,Shape<_32,_2>>, + Stride,Stride< _8,_1024>>>; + using CLayout = SM80_16x8_Row; +}; + +/////////////////////////////////////////////////////////////////////////////// +/////////////////////////// s32 = b1 & b1 + s32 /////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +template <> +struct MMA_Traits + : MMA_Traits {}; + +template<> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = cute::uint1b_t; + using ValTypeB = cute::uint1b_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_8,_8,_128>; + using ThrID = Layout<_32>; + using ALayout = Layout,_32>, + Stride,_8>>; + using BLayout = Layout,_32>, + Stride,_8>>; + using CLayout = SM80_8x8_Row; +}; + +template <> +struct MMA_Traits + :MMA_Traits {}; + +template<> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = cute::uint1b_t; + using ValTypeB = cute::uint1b_t; + using ValTypeC = int32_t; + + using Shape_MNK = Shape<_16,_8,_128>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape<_32,_2>>, + Stride,Stride>>>; + using BLayout = Layout,_32>, + Stride,_8>>; + using CLayout = SM80_16x8_Row; +}; + +template <> +struct MMA_Traits + :MMA_Traits {}; + +} // end namespace cute diff --git a/include/cute/atom/mma_traits_sm90.hpp b/include/cute/atom/mma_traits_sm90.hpp new file mode 100644 index 0000000000..b2ced3f878 --- /dev/null +++ b/include/cute/atom/mma_traits_sm90.hpp @@ -0,0 +1,144 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include +#include + +#include + +namespace cute { + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////// fp64 = fp64 * fp64 + fp64 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +using SM90_16x8x4_F64F64F64F64_TN = SM90::MMA_16x8x4_F64F64F64F64_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = double; + using ValTypeA = double; + using ValTypeB = double; + using ValTypeC = double; + + using Shape_MNK = Shape<_16,_8,_4>; + using ThrID = Layout<_32>; + using ALayout = Layout,_2>, + Stride,_8>>; + using BLayout = Layout,_1>, + Stride,_0>>; + using CLayout = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; +}; + +using SM90_16x8x8_F64F64F64F64_TN = SM90::MMA_16x8x8_F64F64F64F64_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = double; + using ValTypeA = double; + using ValTypeB = double; + using ValTypeC = double; + + using Shape_MNK = Shape<_16,_8,_8>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape <_2, _2>>, + Stride,Stride<_8,_64>>>; + using BLayout = Layout, _2>, + Stride,_32>>; + using CLayout = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; +}; + +using SM90_16x8x16_F64F64F64F64_TN = SM90::MMA_16x8x16_F64F64F64F64_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = double; + using ValTypeA = double; + using ValTypeB = double; + using ValTypeC = double; + + using Shape_MNK = Shape<_16,_8,_16>; + using ThrID = Layout<_32>; + using ALayout = Layout,Shape <_2, _4>>, + Stride,Stride<_8,_64>>>; + using BLayout = Layout, _4>, + Stride,_32>>; + using CLayout = Layout,Shape < _2,_2>>, + Stride,Stride<_16,_8>>>; +}; + +/////////////////////////////////////////////////////////////////////////////////// +//////////////////////// cfp64 = cfp64 * cfp64 + cfp64 //////////////////////////// +/////////////////////////////////////////////////////////////////////////////////// + +using SM90_16x8x4_C64C64C64C64_TN = SM90::MMA_16x8x4_C64C64C64C64_TN; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = complex; + using ValTypeA = complex; + using ValTypeB = complex; + using ValTypeC = complex; +}; + +using SM90_16x8x8_C64C64C64C64_TN = SM90::MMA_16x8x8_C64C64C64C64_TN; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = complex; + using ValTypeA = complex; + using ValTypeB = complex; + using ValTypeC = complex; +}; + +using SM90_16x8x16_C64C64C64C64_TN = SM90::MMA_16x8x16_C64C64C64C64_TN; + +template <> +struct MMA_Traits + : MMA_Traits +{ + using ValTypeD = complex; + using ValTypeA = complex; + using ValTypeB = complex; + using ValTypeC = complex; +}; + +} // end namespace cute diff --git a/include/cute/atom/mma_traits_sm90_gmma.hpp b/include/cute/atom/mma_traits_sm90_gmma.hpp new file mode 100644 index 0000000000..8f59ff55b4 --- /dev/null +++ b/include/cute/atom/mma_traits_sm90_gmma.hpp @@ -0,0 +1,8998 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // cute::smem_ptr_flag +#include // cute::smem_sparse_ptr_flag +#include // cute::Swizzle +#include // cute::Tensor +#include // cute::LayoutType +#include // cute::SM90_64x8x16_F16F16F16_SS, etc +#include // cute::MMA_Traits +#include // cute::ComposedLayout +#include // cute::is_static + +namespace cute { + +// Fence between the async destination accumulators of GMMA & source for their dependent use +template +CUTE_HOST_DEVICE +void +warpgroup_fence_operand(Tensor& frg) { + CUTE_STATIC_ASSERT(is_static::value); + if constexpr (is_same_v) { + auto f32_frg = recast(frg); + CUTE_UNROLL + for (int i = 0; i < size(f32_frg); ++i) { + warpgroup_fence_operand(f32_frg(i)); + } + } + else { + CUTE_STATIC_ASSERT(is_rmem::value); + auto u32_frg = recast(frg); + CUTE_UNROLL + for (int i = 0; i < size(u32_frg); ++i) { + warpgroup_fence_operand(u32_frg(i)); + } + } +} + +namespace SM90::GMMA { + +/////////////////////////////////////////// +// Common layouts for GMMA Shared Memory // +/////////////////////////////////////////// + +// M|N-major GMMA layouts in units of bits +using Layout_MN_INTER_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1, _128>>>; +using Layout_MN_SW32_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1, _256>>>; +using Layout_MN_SW64_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1, _512>>>; +using Layout_MN_SW128_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1,_1024>>>; + +// K-major GMMA layouts in units of bits +using Layout_K_INTER_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride< _128,_1>>>; +using Layout_K_SW32_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride< _256,_1>>>; +using Layout_K_SW64_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride< _512,_1>>>; +using Layout_K_SW128_Atom_Bits = ComposedLayout, smem_ptr_flag, Layout,Stride<_1024,_1>>>; + +// M|N-major layouts in units of Type +template +using Layout_MN_INTER_Atom = decltype(upcast::value>(Layout_MN_INTER_Atom_Bits{})); +template +using Layout_MN_SW32_Atom = 
decltype(upcast::value>(Layout_MN_SW32_Atom_Bits{})); +template +using Layout_MN_SW64_Atom = decltype(upcast::value>(Layout_MN_SW64_Atom_Bits{})); +template +using Layout_MN_SW128_Atom = decltype(upcast::value>(Layout_MN_SW128_Atom_Bits{})); + +// K-major layouts in units of Type +template +using Layout_K_INTER_Atom = decltype(upcast::value>(Layout_K_INTER_Atom_Bits{})); +template +using Layout_K_SW32_Atom = decltype(upcast::value>(Layout_K_SW32_Atom_Bits{})); +template +using Layout_K_SW64_Atom = decltype(upcast::value>(Layout_K_SW64_Atom_Bits{})); +template +using Layout_K_SW128_Atom = decltype(upcast::value>(Layout_K_SW128_Atom_Bits{})); + +// With GMMA::Major param +template +using Layout_INTER_Atom = typename conditional, + Layout_K_INTER_Atom>::type; +template +using Layout_SW32_Atom = typename conditional, + Layout_K_SW32_Atom>::type; +template +using Layout_SW64_Atom = typename conditional, + Layout_K_SW64_Atom>::type; +template +using Layout_SW128_Atom = typename conditional, + Layout_K_SW128_Atom>::type; + +// +// Tensor (position-dependent swizzle) to LayoutType utility +// + +template +CUTE_HOST_DEVICE constexpr +LayoutType +layout_type(Tensor> const&) +{ + static_assert(is_same::value, + "Expected uint128_t type in LayoutType conversion."); + + using Swizzle = get_swizzle_t; + constexpr int B = Swizzle::num_bits; + constexpr int M = Swizzle::num_base; + constexpr int S = Swizzle::num_shft; + + static_assert(M == 4, "Unsupported layout swizzle"); + static_assert(0 <= B && B <= 3, "Unsupported layout swizzle"); + static_assert(S == 3, "Unsupported layout swizzle"); + + switch (B) { + case 0: return LayoutType::INTERLEAVE; + case 1: return LayoutType::B32; + case 2: return LayoutType::B64; + case 3: return LayoutType::B128; + } + return LayoutType::INTERLEAVE; // ERROR +} + +/////////////////////////////////////////////////////////////////////////////// +// Construction method for GMMA Descriptors +/////////////////////////////////////////////////////////////////////////////// + +/** +* /////////////////////////////// +* // make_gmma_desc // +* /////////////////////////////// +* Each GmmaDescriptor Major-MN describes a canonical layout of the form +* +* LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((T,1,m),(8,k)):((1,T,SBO),(1T,LBO)) +* LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((T,2,m),(8,k)):((1,T,LBO),(2T,SBO)) +* LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((T,4,m),(8,k)):((1,T,LBO),(4T,SBO)) +* LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((T,8,m),(8,k)):((1,T,LBO),(8T,SBO)) +* +* where +* T : sizeof(uint128_t) / sizeof(value_type) +* m : integer in [1,16] corresponding to GMMA shape +* k : integer in [1,32] corresponding to GMMA shape +* SBO: stride byte offset +* LBO: leading byte offset +* +* See GMMA::Layout_MN_XXX_Atom for building canonical GmmaDescriptor Major-MN layouts. +* For example, +* auto smem_layout = tile_to_shape(Layout_MN_SW128_Atom{}, Shape<_128,_64>{}); +* is guaranteed to be accepted by make_gmma_desc for appropriate value_type. 
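*
* As a concrete illustration (a hypothetical instantiation of the table above, not an
* additional guarantee): for value_type = half_t,
*     T = sizeof(uint128_t) / sizeof(half_t) = 8,
* so the LayoutType::B128 row reads
*     Swizzle<3,4,3> o smem_ptr o ((8,8,m),(8,k)):((1,8,LBO),(64,SBO))
* with m, k, SBO, and LBO determined by the actual shared memory layout.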
+* +* ////////////////////////////// +* // make_gmma_desc // +* ////////////////////////////// +* Each GmmaDescriptor Major-K describes a canonical layout of the form +* +* LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((8,m),(T,2)):((1T,SBO),(1,LBO)) +* LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((8,m),(T,2)):((2T,SBO),(1, T )) +* LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((8,m),(T,2)):((4T,SBO),(1, T )) +* LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,m),(T,2)):((8T,SBO),(1, T )) +* +* See GMMA::Layout_K_XXX_Atom for building canonical GmmaDescriptor Major-K layouts. +* For example, +* auto smem_layout = tile_to_shape(Layout_K_SW128_Atom{}, Shape<_128,_64>{}); +* is guaranteed to be accepted by make_gmma_desc for appropriate value_type. +*/ +template +CUTE_HOST_DEVICE constexpr +GmmaDescriptor +make_gmma_desc(Tensor const& tensor) +{ + static_assert(is_smem::value, "GMMA Descriptors can only be constructed on smem."); + static_assert(TLayout::rank == 2, "GMMA Descriptors can only be constructed on rank-2 tensors."); + using value_type = typename TEngine::value_type; + + Tensor u128_tensor = recast(tensor); + + // Result + GmmaDescriptor desc; + + // Layout type + constexpr LayoutType LAYOUT_TYPE = layout_type(u128_tensor); + desc.bitfield.layout_type_ = uint8_t(LAYOUT_TYPE); + + // Start address (4LSB not included) + uint32_t start_address = cast_smem_ptr_to_uint(raw_pointer_cast(u128_tensor.data())); + desc.bitfield.start_address_ = static_cast(start_address >> 4); + + constexpr uint8_t base_offset = 0; + desc.bitfield.base_offset_ = base_offset; + + // LayoutType meta + constexpr int W = LAYOUT_TYPE == LayoutType::INTERLEAVE ? 1 : + LAYOUT_TYPE == LayoutType::B32 ? 2 : + LAYOUT_TYPE == LayoutType::B64 ? 4 : + LAYOUT_TYPE == LayoutType::B128 ? 8 : -1; + + if constexpr (MajorMode == Major::MN) + { + /* In units of uint128_t, each GmmaDescriptor Major-MN describes a canonical layout of the form + * + * LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((1,n),(8,k)):((X,SBO),(1,LBO)) + * LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((2,n),(8,k)):((1,LBO),(2,SBO)) + * LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((4,n),(8,k)):((1,LBO),(4,SBO)) + * LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,n),(8,k)):((1,LBO),(8,SBO)) + */ + static_assert(size<1>(u128_tensor) == Int<(256 / cute::sizeof_bits::value)>{} || // A and B in dense MMA + size<1>(u128_tensor) == Int<(128 / cute::sizeof_bits::value)>{} || // A in sparse MMA + size<1>(u128_tensor) == Int<(512 / cute::sizeof_bits::value)>{}, // B in sparse MMA + "Not a canonical GMMA_MN Layout: Expected K-size 256/sizeof_bits for dense or (128|512)/sizeof_bits for sparse."); + + // Construct the canonical GMMA T Layout with shape ((W,n),(8,2)) + Layout canonical_layout = logical_divide(layout(u128_tensor), make_tile(Layout,_1>{}, Layout,_1>{})); + + // Check ranks of canonical + CUTE_STATIC_ASSERT_V(rank<0>(canonical_layout) == Int<2>{}, "Not a canonical GMMA_MN Layout: No flat offset mode"); + CUTE_STATIC_ASSERT_V(rank<1>(canonical_layout) == Int<2>{}, "Not a canonical GMMA_MN Layout: No flat offset mode"); + // Check canonical mode strides + constexpr uint32_t stride_00 = stride<0,0>(canonical_layout); + constexpr uint32_t expected_stride_00 = LAYOUT_TYPE == LayoutType::INTERLEAVE ? 
stride<0,0>(canonical_layout) : 1; + static_assert(stride_00 == expected_stride_00, "Not a canonical GMMA_MN Layout: Expected stride failure."); + constexpr uint32_t stride_10 = stride<1,0>(canonical_layout); + constexpr uint32_t expected_stride_10 = W; + static_assert(stride_10 == expected_stride_10, "Not a canonical GMMA_MN Layout: Expected stride failure."); + + // stride dimension byte offset and leading dimension byte offset (4LSB not included == uint128_t units) + constexpr uint32_t stride_01 = stride<0,1>(canonical_layout); + constexpr uint32_t stride_11 = stride<1,1>(canonical_layout); + + desc.bitfield.stride_byte_offset_ = (LAYOUT_TYPE == LayoutType::INTERLEAVE) ? stride_01 : stride_11; + desc.bitfield.leading_byte_offset_ = (LAYOUT_TYPE == LayoutType::INTERLEAVE) ? stride_11 : stride_01; + } + else if constexpr (MajorMode == Major::K) + { + /* In units of uint128_t, each GmmaDescriptor Major-K describes a canonical layout of the form + * + * LayoutType::INTERLEAVE : Swizzle<0,4,3> o smem_ptr o ((8,n),2):((1,SBO),LBO) + * LayoutType::B32 : Swizzle<1,4,3> o smem_ptr o ((8,n),2):((2,SBO),1) + * LayoutType::B64 : Swizzle<2,4,3> o smem_ptr o ((8,n),2):((4,SBO),1) + * LayoutType::B128 : Swizzle<3,4,3> o smem_ptr o ((8,n),2):((8,SBO),1) + */ + CUTE_STATIC_ASSERT_V(size<0>(u128_tensor) % Int<8>{} == Int<0>{}, // N|M size + "Not a canonical GMMA_K Layout: Expected MN-size multiple of 8."); + CUTE_STATIC_ASSERT_V(size<1>(u128_tensor) == Int<2>{} || size<1>(u128_tensor) == Int<4>{}, // K size + "Not a canonical GMMA_K Layout: Expected K-size 2 for dense or 4 for sparse (in units of uint128_t)."); + + // Construct the canonical GMMA N Layout with shape ((8,n),(2,1)) + Layout canonical_layout = logical_divide(layout(u128_tensor), make_tile(Layout<_8,_1>{}, Layout<_2,_1>{})); + + // Check ranks of canonical + CUTE_STATIC_ASSERT_V(rank<0>(canonical_layout) == Int<2>{}, "Not a canonical GMMA_K Layout: No flat offset mode"); + CUTE_STATIC_ASSERT_V(rank<1>(canonical_layout) == Int<2>{}, "Not a canonical GMMA_K Layout: No flat offset mode"); + // Check canonical mode strides + constexpr uint32_t stride_00 = stride<0,0>(canonical_layout); + constexpr uint32_t expected_stride_00 = W; + static_assert(stride_00 == expected_stride_00, "Not a canonical GMMA_K Layout: Expected stride failure."); + constexpr uint32_t stride_10 = stride<1,0>(canonical_layout); + constexpr uint32_t expected_stride_10 = (LAYOUT_TYPE == LayoutType::INTERLEAVE) ? 
stride<1,0>(canonical_layout) : 1; + static_assert(stride_10 == expected_stride_10, "Not a canonical GMMA_K Layout: Expected stride failure."); + + // stride dimension byte offset and leading dimension byte offset (4LSB not included == uint128_t units) + constexpr uint32_t stride_01 = stride<0,1>(canonical_layout); + + desc.bitfield.stride_byte_offset_ = stride_01; + desc.bitfield.leading_byte_offset_ = stride_10; + } else { + static_assert(MajorMode != Major::MN && MajorMode != Major::K, "Unrecognized MajorMode!"); + } + +#if 0 + // DEBUG and SANITY + assert((start_address & 0b0000001111) == 0); // Must be 16B aligned (4LSB are 0) no negotiation + assert((start_address & 0b1110000000) == 0); // Assert base_offset is 0, generalize later + if (thread0()) { + print("smem_desc input tensor: "); print(tensor.data()); print(" o "); print(tensor.layout()); print("\n"); + print("smem_desc uint128_t tensor: "); print(u128_tensor.data()); print(" o "); print(u128_tensor.layout()); print("\n"); + //print(" desc canonical layout: "); print(canonical_layout); print("\n"); + print(desc); + } +#endif + + return desc; +} + +/////////////////////////////////////////////////////////////////////////////// +// Higher level GMMA Descriptor utilities +/////////////////////////////////////////////////////////////////////////////// + +struct DescriptorIterator +{ + using reference = GmmaDescriptor; + using element_type = GmmaDescriptor; + using value_type = GmmaDescriptor; + + GmmaDescriptor desc_; + + // Dereference returns the GmmaDescriptor + CUTE_HOST_DEVICE constexpr + reference operator*() const { return desc_; } + + // Advance and return a new GmmaDescriptor + template + CUTE_HOST_DEVICE constexpr + reference operator[](Index const& i) const { return *(*this + i); } + + // Return an advanced iterator + template + CUTE_HOST_DEVICE constexpr + DescriptorIterator operator+(Index const& offset) const + { + return { GmmaDescriptor{desc_ + uint64_t(offset)} }; + } +}; + +template +CUTE_HOST_DEVICE constexpr +GmmaDescriptor +raw_pointer_cast(DescriptorIterator const& ptr) { + return ptr.desc_; +} + +// Recast a DescriptorIterator Tensor to uint64_t, it's RegType in mma_unpack +template +CUTE_HOST_DEVICE constexpr +DescriptorIterator +recast_ptr(DescriptorIterator const& iter) { + static_assert(is_same::value, "Can only cast GmmaDescriptorIterator to uint64_t."); + return iter; // Do nothing, it will still dereference to GmmaDescriptor and decay to uint64_t +} + +CUTE_HOST_DEVICE void +print(DescriptorIterator) { + printf("GMMA::DescriptorIterator"); +} + +// The GMMA Traits below have custom fragment type flags for their smem desc tensors. +// These flags specialize a MakeTensor customization point to correctly make the fragment that is desired. 
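// Illustrative sketch only: the helper below is hypothetical (not part of the CUTLASS
// API). The smem_desc fragments declared next ultimately reduce to a make_gmma_desc
// call on a canonical layout, e.g. a K-major SW128 atom tiled to a 128x64 half_t tile.
template <class T = half_t>
CUTE_DEVICE void
example_make_k_major_gmma_desc(T* smem)
{
  // Canonical K-major, 128B-swizzled layout atom tiled to an example 128x64 tile
  auto smem_layout = tile_to_shape(Layout_K_SW128_Atom<T>{}, Shape<_128,_64>{});
  // View the shared memory buffer through that layout
  Tensor sA = make_tensor(make_smem_ptr(smem), smem_layout);
  // Encode start address, layout type, LBO, and SBO as a Major-K GMMA descriptor
  GmmaDescriptor desc = make_gmma_desc<Major::K>(sA);
  (void) desc;
}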
+template +struct smem_desc : DescriptorIterator {}; + +} // end namespace SM90::GMMA + +// Customization point for creating a GMMA::smem_desc Tensor +template +struct MakeTensor> +{ + template + CUTE_HOST_DEVICE constexpr auto + operator()(Tensor const& smem_tensor) + { + static_assert(is_smem::value, "Expected SMEM Tensor to construct a GMMA Desc Tensor"); + return make_tensor(SM90::GMMA::DescriptorIterator{SM90::GMMA::make_gmma_desc(tensor<0>(smem_tensor))}, + replace<0>(recast(smem_tensor).layout(), Layout<_1,_0>{})); + } +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////////// MMA_TRAITS /////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +namespace SM90::GMMA { + +// +// Specialized mma_unpack implementation for SM90 GMMA instructions +// + +template +CUTE_HOST_DEVICE constexpr +void +mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A, + Tensor const& B, + Tensor const& C) +{ + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem::value, "Expected registers in MMA_Atom::call"); + + // Register value types from the MMA_Operation register arrays + using RegTypeA = typename remove_extent::type; + using RegTypeB = typename remove_extent::type; + using RegTypeC = typename remove_extent::type; + + // SM90 GMMA take three arguments rather than four, try to assert C and D are aliased + static_assert(is_same::value, "GMMA C and D value_type must match."); + static_assert(is_same::value, "GMMA C and D layouts must match."); + // assert((void*)&C == (void*)&D); + + Tensor rA = recast(A); + Tensor rB = recast(B); + Tensor rC = recast(D); // NOTE: D and C are same, so use mutable D + + constexpr int RegNumA = extent::value; + constexpr int RegNumB = extent::value; + constexpr int RegNumC = extent::value; + + CUTE_STATIC_ASSERT_V(size(rA) == Int{}); + CUTE_STATIC_ASSERT_V(size(rB) == Int{}); + CUTE_STATIC_ASSERT_V(size(rC) == Int{}); + + detail::explode(MMA_Op::fma, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}, + &(traits.accumulate_), seq<0>{}); +} + +// Accumulator layouts +template +using CLayout_64xN = Layout,Shape < _2,_2,Int>>, + Stride,Stride<_64,_8, _512>>>; + +using CLayout_64x8 = CLayout_64xN< 8>; +using CLayout_64x16 = CLayout_64xN< 16>; +using CLayout_64x32 = CLayout_64xN< 32>; +using CLayout_64x64 = CLayout_64xN< 64>; +using CLayout_64x96 = CLayout_64xN< 96>; +using CLayout_64x128 = CLayout_64xN<128>; +using CLayout_64x192 = CLayout_64xN<192>; +using CLayout_64x256 = CLayout_64xN<256>; + +// Register source layout for 32-bit value types +using ALayout_64x8 = Layout,Shape < _2, _2>>, + Stride,Stride< _8,_256>>>; + +// Register source layout for 16-bit (sparse 32-bit) value types +using ALayout_64x16 = Layout,Shape < _2,_2, _2>>, + Stride,Stride<_64,_8,_512>>>; + +// Register source layout for 8-bit (sparse 16-bit) value types +using ALayout_64x32 = Layout,Shape < _4,_2, _2>>, + Stride,Stride<_64,_8,_1024>>>; + +// Register source layout for sparse 8-bit value types +using ALayout_64x64 = Layout,Shape < _8,_2, _2>>, + Stride,Stride<_64,_8,_2048>>>; + +// Shared memory source layouts for any value type +template +using ABLayout = Layout,Int>>, + Stride< _0,Stride< _1,Int>>>; + +} // end namespace SM90::GMMA + +using 
namespace SM90; + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x16_F16F16F16_SS = SM90::GMMA::MMA_64x8x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x16_F16F16F16_RS = SM90::GMMA::MMA_64x8x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x16_F16F16F16_SS = SM90::GMMA::MMA_64x16x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x16_F16F16F16_RS = SM90::GMMA::MMA_64x16x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x16_F16F16F16_SS = SM90::GMMA::MMA_64x32x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x16_F16F16F16_RS = SM90::GMMA::MMA_64x32x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x16_F16F16F16_SS = SM90::GMMA::MMA_64x64x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x16_F16F16F16_RS = SM90::GMMA::MMA_64x64x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x16_F16F16F16_SS = SM90::GMMA::MMA_64x96x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x16_F16F16F16_RS = SM90::GMMA::MMA_64x96x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x16_F16F16F16_SS = SM90::GMMA::MMA_64x128x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x16_F16F16F16_RS = SM90::GMMA::MMA_64x128x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x16_F16F16F16_SS = SM90::GMMA::MMA_64x192x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x16_F16F16F16_RS = SM90::GMMA::MMA_64x192x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + 
using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x16_F16F16F16_SS = SM90::GMMA::MMA_64x256x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x16_F16F16F16_RS = SM90::GMMA::MMA_64x256x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x16_F32F16F16_SS = SM90::GMMA::MMA_64x8x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x16_F32F16F16_RS = SM90::GMMA::MMA_64x8x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using 
SM90_64x16x16_F32F16F16_SS = SM90::GMMA::MMA_64x16x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x16_F32F16F16_RS = SM90::GMMA::MMA_64x16x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x16_F32F16F16_SS = SM90::GMMA::MMA_64x32x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x16_F32F16F16_RS = SM90::GMMA::MMA_64x32x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x16_F32F16F16_SS = SM90::GMMA::MMA_64x64x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 64, 16>; + 
using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x16_F32F16F16_RS = SM90::GMMA::MMA_64x64x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x16_F32F16F16_SS = SM90::GMMA::MMA_64x96x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x16_F32F16F16_RS = SM90::GMMA::MMA_64x96x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x16_F32F16F16_SS = SM90::GMMA::MMA_64x128x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x16_F32F16F16_RS = SM90::GMMA::MMA_64x128x16_F32F16F16_RS; + +template 
+struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x16_F32F16F16_SS = SM90::GMMA::MMA_64x192x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x16_F32F16F16_RS = SM90::GMMA::MMA_64x192x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x16_F32F16F16_SS = SM90::GMMA::MMA_64x256x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x16_F32F16F16_RS = SM90::GMMA::MMA_64x256x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
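// Illustrative usage sketch (kept as a comment; make_tiled_mma and print_latex are
// declared in <cute/atom/mma_atom.hpp>, not in this header): any of the SS/RS ops
// above can be handed to make_tiled_mma to build a warpgroup-wide TiledMMA. The
// 64x64x16 F16 SS op with K-major A and B is an arbitrary example choice.
//
//   auto tiled_mma = make_tiled_mma(
//       SM90_64x64x16_F16F16F16_SS<GMMA::Major::K, GMMA::Major::K>{});
//   print_latex(tiled_mma);   // inspect the thread/value partitioning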
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x8x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x8x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x16x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x16x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x32x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = 
bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x32x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x64x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x64x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x96x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; 
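// Worked size note (illustrative): the CLayout_64xN accumulator layouts used above
// distribute the 64xN C/D tile over the 128 threads of a warpgroup, i.e.
//   (64 * N) / 128 = N/2 accumulator values per thread,
// so the 64x96 BF16->F32 traits above leave each thread holding 48 F32 accumulators:
//   size(GMMA::CLayout_64x96{}) == 64 * 96,  and  64 * 96 / 128 == 48.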
+ +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x96x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x128x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x128x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x192x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x192x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = 
float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x256x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x256x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x8x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 8, 8>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x8x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 8, 8>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
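// Note (illustrative): unlike the F16/BF16 ops above, the TF32 ops are TN-only
// (both A and B are consumed K-major), so their aliases expose no GMMA::Major
// parameters, only the GMMA::ScaleIn inputs. With the defaults shown above,
//
//   SM90_64x8x8_F32TF32TF32_SS_TN<>
//
// is shorthand for
//
//   SM90_64x8x8_F32TF32TF32_SS_TN<GMMA::ScaleIn::One, GMMA::ScaleIn::One>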
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x16x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x16x8_F32TF32TF32_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x16x8_F32TF32TF32_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = tfloat32_t;
+  using ValTypeB = tfloat32_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_16,_8>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 8>;
+  using BLayout = GMMA::ABLayout< 16, 8>;
+  using CLayout = GMMA::CLayout_64x16;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x16x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x16x8_F32TF32TF32_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x16x8_F32TF32TF32_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = tfloat32_t;
+  using ValTypeB = tfloat32_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_16,_8>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x8;
+  using BLayout = GMMA::ABLayout< 16, 8>;
+  using CLayout = GMMA::CLayout_64x16;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x32x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x32x8_F32TF32TF32_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x32x8_F32TF32TF32_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = tfloat32_t;
+  using ValTypeB = tfloat32_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_32,_8>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 8>;
+  using BLayout = GMMA::ABLayout< 32, 8>;
+  using CLayout = GMMA::CLayout_64x32;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x32x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x32x8_F32TF32TF32_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x32x8_F32TF32TF32_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = tfloat32_t;
+  using ValTypeB = tfloat32_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_32,_8>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x8;
+  using BLayout = GMMA::ABLayout< 32, 8>;
+  using CLayout = GMMA::CLayout_64x32;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x64x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x64x8_F32TF32TF32_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x64x8_F32TF32TF32_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = tfloat32_t;
+  using ValTypeB = tfloat32_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_64,_8>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 8>;
+  using BLayout = GMMA::ABLayout< 64, 8>;
+  using CLayout = GMMA::CLayout_64x64;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x64x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x64x8_F32TF32TF32_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x64x8_F32TF32TF32_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = tfloat32_t;
+  using ValTypeB = tfloat32_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_64,_8>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x8;
+  using BLayout = GMMA::ABLayout< 64, 8>;
+  using CLayout = GMMA::CLayout_64x64;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x96x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x96x8_F32TF32TF32_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x96x8_F32TF32TF32_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = tfloat32_t;
+  using ValTypeB = tfloat32_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_96,_8>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 8>;
+  using BLayout = GMMA::ABLayout< 96, 8>;
+  using CLayout = GMMA::CLayout_64x96;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x96x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x96x8_F32TF32TF32_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x96x8_F32TF32TF32_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = tfloat32_t;
+  using ValTypeB = tfloat32_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_96,_8>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x8;
+  using BLayout = GMMA::ABLayout< 96, 8>;
+  using CLayout = GMMA::CLayout_64x96;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x128x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x128x8_F32TF32TF32_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x128x8_F32TF32TF32_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = tfloat32_t;
+  using ValTypeB = tfloat32_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_128,_8>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 8>;
+  using BLayout = GMMA::ABLayout<128, 8>;
+  using CLayout = GMMA::CLayout_64x128;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x128x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x128x8_F32TF32TF32_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x128x8_F32TF32TF32_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+ using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<128, 8>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x192x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<192, 8>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x192x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<192, 8>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x256x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<256, 8>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x256x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<256, 8>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8S8_SS_TN = 
SM90::GMMA::MMA_64x8x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x16x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x32x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using 
FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x64x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x96x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x128x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = 
GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x192x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x256x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x8x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x16x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x32x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = 
Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x64x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x96x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x128x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8S8_RS_TN_SATURATE = 
SM90::GMMA::MMA_64x128x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x192x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x256x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x8x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using 
BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x16x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x32x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x64x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x96x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x128x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32S8U8_SS_TN_SATURATE; + 
+template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x192x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x256x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x8x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + 
using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x16x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x32x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8U8_RS_TN = 
SM90::GMMA::MMA_64x64x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x96x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x128x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + 
using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x192x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x256x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x8x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = 
int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x16x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x32x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x64x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + 
using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x96x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x128x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x192x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x256x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x8x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8S8_RS_TN_SATURATE = 
SM90::GMMA::MMA_64x8x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x16x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x32x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x64x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = 
GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x96x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x128x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x192x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = 
int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x256x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x8x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x16x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x32x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x64x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8U8_SS_TN_SATURATE = 
SM90::GMMA::MMA_64x64x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x96x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x128x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x192x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = 
int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x256x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x256x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x8x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x8x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x8x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = 
GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x16x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x16x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x16x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x32x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x32x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x32x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x64x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x64x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x64x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + 
using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x96x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x96x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x96x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x128x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x128x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x128x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x192x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x192x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
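For context, the specializations in this hunk only declare trait metadata; CuTe consults an MMA_Traits specialization like these when the corresponding GMMA op is used as an MMA atom. The short sketch below is illustrative only and is not part of the diff: it assumes a standard CUTLASS/CuTe include setup, and the alias name TiledMmaU8S8 and the single-atom tiling are arbitrary choices made for the example.

#include <cute/tensor.hpp>
#include <cute/atom/mma_atom.hpp>

using namespace cute;

// Build a TiledMMA from one of the ops whose traits appear in this hunk:
// one warpgroup (128 threads, matching ThrID = Layout<_128>) drives a single
// 64x64x32 S32 += U8 * S8 atom per MMA tile.
using TiledMmaU8S8 = decltype(make_tiled_mma(
    SM90_64x64x32_S32U8S8_SS_TN{},
    Layout<Shape<_1,_1,_1>>{}));

// Inside a kernel, each thread would then take its slice as usual:
//   TiledMmaU8S8 tiled_mma;
//   auto thr_mma = tiled_mma.get_thread_slice(threadIdx.x);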
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+using SM90_64x192x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x192x32_S32U8U8_RS_TN_SATURATE;
+
+template <>
+struct MMA_Traits<SM90_64x192x32_S32U8U8_RS_TN_SATURATE>
+{
+  using ValTypeD = int32_t;
+  using ValTypeA = uint8_t;
+  using ValTypeB = uint8_t;
+  using ValTypeC = int32_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_192,_32>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<192, 32>;
+  using CLayout = GMMA::CLayout_64x192;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+using SM90_64x256x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x256x32_S32U8U8_RS_TN;
+
+template <>
+struct MMA_Traits<SM90_64x256x32_S32U8U8_RS_TN>
+{
+  using ValTypeD = int32_t;
+  using ValTypeA = uint8_t;
+  using ValTypeB = uint8_t;
+  using ValTypeC = int32_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_256,_32>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<256, 32>;
+  using CLayout = GMMA::CLayout_64x256;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+using SM90_64x256x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x256x32_S32U8U8_RS_TN_SATURATE;
+
+template <>
+struct MMA_Traits<SM90_64x256x32_S32U8U8_RS_TN_SATURATE>
+{
+  using ValTypeD = int32_t;
+  using ValTypeA = uint8_t;
+  using ValTypeB = uint8_t;
+  using ValTypeC = int32_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_256,_32>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<256, 32>;
+  using CLayout = GMMA::CLayout_64x256;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x8x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x8x32_F16E4M3E4M3_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x8x32_F16E4M3E4M3_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e4m3_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_8,_32>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout< 8, 32>;
+  using CLayout = GMMA::CLayout_64x8;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x8x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x8x32_F16E4M3E4M3_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x8x32_F16E4M3E4M3_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e4m3_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_8,_32>;
+  using ThrID = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout< 8, 32>;
+  using CLayout = GMMA::CLayout_64x8;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x8x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x8x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x16x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x16x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x16x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout 
= GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x16x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x32x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x32x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x32x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x32x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using 
ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x64x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x64x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x64x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x64x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn 
scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x96x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x96x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x96x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x96x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x128x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout 
= GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x128x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x128x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x128x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x192x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x192x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; 
+ using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x192x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x192x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x256x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x256x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x256x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x256x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x8x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x8x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x8x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = 
GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x8x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x16x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x16x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x16x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x16x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + 
+ using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x32x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x32x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x32x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x32x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using 
SM90_64x64x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x64x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x64x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x64x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x64x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x96x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x96x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x96x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x96x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x128x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x128x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x128x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x128x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x192x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x192x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using 
SM90_64x192x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x192x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x192x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x256x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x256x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x256x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; 
+ + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x256x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x8x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x8x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x8x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x8x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x16x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x16x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x16x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x16x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E5M2E4M3_SS_TN = 
SM90::GMMA::MMA_64x32x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x32x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x32x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x32x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x64x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x64x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x64x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x64x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x96x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x96x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using 
Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x96x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x96x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x128x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x128x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E5M2E4M3_SS_TN = 
SM90::GMMA::MMA_64x128x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x128x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x192x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x192x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x192x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x192x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x256x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x256x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x256x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x256x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x8x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x8x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x8x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x8x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x8x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E5M2E5M2_SS_TN = 
SM90::GMMA::MMA_64x16x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x16x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x16x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x16x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x16x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x32x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x32x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x32x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x32x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x32x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x64x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x64x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using 
Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x64x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x64x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x64x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x96x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x96x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E5M2E5M2_SS_TN = 
SM90::GMMA::MMA_64x96x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x96x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x96x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x128x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x128x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x128x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x128x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x128x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x192x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x192x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x192x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x192x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x192x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x256x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x256x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x256x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x256x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x256x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +#include "mma_traits_sm90_gmma_ext.hpp" +#endif 
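For context, here is a minimal sketch of how one of the FP8 GMMA trait specializations declared in this header is typically consumed through CuTe's `make_tiled_mma`. This is an illustrative example and not part of the patch; it assumes an `nvcc` build targeting `sm_90a` with the CUTLASS/CuTe headers on the include path, and relies on the defaulted `GMMA::ScaleIn::One` scale arguments.

```cpp
// Illustrative sketch only (not part of the patch). Assumes nvcc targeting sm_90a
// and the CUTLASS/CuTe headers on the include path.
#include <cute/tensor.hpp>
#include <cute/atom/mma_atom.hpp>

using namespace cute;

int main() {
  // 64x128x32 E5M2 x E5M2 -> F32 atom, both operands read from shared memory (SS),
  // TN layout; scaleA/scaleB take their GMMA::ScaleIn::One defaults.
  auto tiled_mma = make_tiled_mma(SM90_64x128x32_F32E5M2E5M2_SS_TN<>{});

  // One GMMA atom is executed by a full warpgroup (ThrID = Layout<_128>).
  print(size(tiled_mma));   // number of participating threads
  print_latex(tiled_mma);   // dump the thread/value layout for inspection
  return 0;
}
```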
diff --git a/include/cute/atom/mma_traits_sm90_gmma_ext.hpp b/include/cute/atom/mma_traits_sm90_gmma_ext.hpp new file mode 100644 index 0000000000..15e2412c87 --- /dev/null +++ b/include/cute/atom/mma_traits_sm90_gmma_ext.hpp @@ -0,0 +1,20116 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include +#include + +namespace cute { + +namespace SM90::GMMA { + +using CLayout_64x24 = CLayout_64xN< 24>; +using CLayout_64x40 = CLayout_64xN< 40>; +using CLayout_64x48 = CLayout_64xN< 48>; +using CLayout_64x56 = CLayout_64xN< 56>; +using CLayout_64x72 = CLayout_64xN< 72>; +using CLayout_64x80 = CLayout_64xN< 80>; +using CLayout_64x88 = CLayout_64xN< 88>; +using CLayout_64x104 = CLayout_64xN<104>; +using CLayout_64x112 = CLayout_64xN<112>; +using CLayout_64x120 = CLayout_64xN<120>; +using CLayout_64x136 = CLayout_64xN<136>; +using CLayout_64x144 = CLayout_64xN<144>; +using CLayout_64x152 = CLayout_64xN<152>; +using CLayout_64x160 = CLayout_64xN<160>; +using CLayout_64x168 = CLayout_64xN<168>; +using CLayout_64x176 = CLayout_64xN<176>; +using CLayout_64x184 = CLayout_64xN<184>; +using CLayout_64x200 = CLayout_64xN<200>; +using CLayout_64x208 = CLayout_64xN<208>; +using CLayout_64x216 = CLayout_64xN<216>; +using CLayout_64x224 = CLayout_64xN<224>; +using CLayout_64x232 = CLayout_64xN<232>; +using CLayout_64x240 = CLayout_64xN<240>; +using CLayout_64x248 = CLayout_64xN<248>; + +} + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x16_F16F16F16_SS = SM90::GMMA::MMA_64x24x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 24, 16>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x16_F16F16F16_RS = SM90::GMMA::MMA_64x24x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 24, 16>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x16_F16F16F16_SS = SM90::GMMA::MMA_64x40x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 40, 16>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x16_F16F16F16_RS = SM90::GMMA::MMA_64x40x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 40, 16>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F16F16F16_SS = SM90::GMMA::MMA_64x48x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F16F16F16_RS = SM90::GMMA::MMA_64x48x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x16_F16F16F16_SS = SM90::GMMA::MMA_64x56x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 56, 16>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x16_F16F16F16_RS = SM90::GMMA::MMA_64x56x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = 
half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 56, 16>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x16_F16F16F16_SS = SM90::GMMA::MMA_64x72x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 72, 16>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x16_F16F16F16_RS = SM90::GMMA::MMA_64x72x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 72, 16>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F16F16F16_SS = SM90::GMMA::MMA_64x80x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F16F16F16_RS = SM90::GMMA::MMA_64x80x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
+template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x16_F16F16F16_SS = SM90::GMMA::MMA_64x88x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 88, 16>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x16_F16F16F16_RS = SM90::GMMA::MMA_64x88x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 88, 16>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x16_F16F16F16_SS = SM90::GMMA::MMA_64x104x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<104, 16>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x16_F16F16F16_RS = SM90::GMMA::MMA_64x104x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<104, 16>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F16F16F16_SS = SM90::GMMA::MMA_64x112x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F16F16F16_RS = SM90::GMMA::MMA_64x112x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x16_F16F16F16_SS = SM90::GMMA::MMA_64x120x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<120, 16>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x16_F16F16F16_RS = SM90::GMMA::MMA_64x120x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<120, 16>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x16_F16F16F16_SS = SM90::GMMA::MMA_64x136x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + 
GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x16_F16F16F16_RS = SM90::GMMA::MMA_64x136x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F16F16F16_SS = SM90::GMMA::MMA_64x144x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F16F16F16_RS = SM90::GMMA::MMA_64x144x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x16_F16F16F16_SS = SM90::GMMA::MMA_64x152x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x16_F16F16F16_RS = SM90::GMMA::MMA_64x152x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using 
ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F16F16F16_SS = SM90::GMMA::MMA_64x160x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F16F16F16_RS = SM90::GMMA::MMA_64x160x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x16_F16F16F16_SS = SM90::GMMA::MMA_64x168x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x16_F16F16F16_RS = SM90::GMMA::MMA_64x168x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn 
scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F16F16F16_SS = SM90::GMMA::MMA_64x176x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F16F16F16_RS = SM90::GMMA::MMA_64x176x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x16_F16F16F16_SS = SM90::GMMA::MMA_64x184x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x16_F16F16F16_RS = SM90::GMMA::MMA_64x184x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x16_F16F16F16_SS = SM90::GMMA::MMA_64x200x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x16_F16F16F16_RS = SM90::GMMA::MMA_64x200x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F16F16F16_SS = SM90::GMMA::MMA_64x208x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F16F16F16_RS = SM90::GMMA::MMA_64x208x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x16_F16F16F16_SS = SM90::GMMA::MMA_64x216x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> 
+using SM90_64x216x16_F16F16F16_RS = SM90::GMMA::MMA_64x216x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F16F16F16_SS = SM90::GMMA::MMA_64x224x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F16F16F16_RS = SM90::GMMA::MMA_64x224x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x16_F16F16F16_SS = SM90::GMMA::MMA_64x232x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<232, 16>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x16_F16F16F16_RS = SM90::GMMA::MMA_64x232x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<232, 16>; + using 
CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F16F16F16_SS = SM90::GMMA::MMA_64x240x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F16F16F16_RS = SM90::GMMA::MMA_64x240x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x16_F16F16F16_SS = SM90::GMMA::MMA_64x248x16_F16F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x16_F16F16F16_RS = SM90::GMMA::MMA_64x248x16_F16F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x16_F32F16F16_SS = SM90::GMMA::MMA_64x24x16_F32F16F16_SS; + 
+template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x24x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = half_t;
+  using ValTypeB = half_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<tnspA>;
+  using FrgTypeB = GMMA::smem_desc<tnspB>;
+
+  using Shape_MNK = Shape<_64,_24,_16>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 16>;
+  using BLayout = GMMA::ABLayout< 24, 16>;
+  using CLayout = GMMA::CLayout_64x24;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x24x16_F32F16F16_RS = SM90::GMMA::MMA_64x24x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>;
+
+template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x24x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = half_t;
+  using ValTypeB = half_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<tnspB>;
+
+  using Shape_MNK = Shape<_64,_24,_16>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x16;
+  using BLayout = GMMA::ABLayout< 24, 16>;
+  using CLayout = GMMA::CLayout_64x24;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x40x16_F32F16F16_SS = SM90::GMMA::MMA_64x40x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>;
+
+template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x40x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = half_t;
+  using ValTypeB = half_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<tnspA>;
+  using FrgTypeB = GMMA::smem_desc<tnspB>;
+
+  using Shape_MNK = Shape<_64,_40,_16>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 16>;
+  using BLayout = GMMA::ABLayout< 40, 16>;
+  using CLayout = GMMA::CLayout_64x40;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x40x16_F32F16F16_RS = SM90::GMMA::MMA_64x40x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>;
+
+template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x40x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = half_t;
+  using ValTypeB = half_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<tnspB>;
+
+  using Shape_MNK = Shape<_64,_40,_16>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x16;
+  using BLayout = GMMA::ABLayout< 40, 16>;
+  using CLayout = GMMA::CLayout_64x40;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x48x16_F32F16F16_SS = SM90::GMMA::MMA_64x48x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>;
+
+template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x48x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = half_t;
+  using ValTypeB = half_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<tnspA>;
+  using FrgTypeB = GMMA::smem_desc<tnspB>;
+
+  using Shape_MNK = Shape<_64,_48,_16>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 16>;
+  using BLayout = GMMA::ABLayout< 48, 16>;
+  using CLayout = GMMA::CLayout_64x48;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x48x16_F32F16F16_RS = SM90::GMMA::MMA_64x48x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>;
+
+template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x48x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = half_t;
+  using ValTypeB = half_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<tnspB>;
+
+  using Shape_MNK = Shape<_64,_48,_16>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x16;
+  using BLayout = GMMA::ABLayout< 48, 16>;
+  using CLayout = GMMA::CLayout_64x48;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x56x16_F32F16F16_SS = SM90::GMMA::MMA_64x56x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>;
+
+template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x56x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = half_t;
+  using ValTypeB = half_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<tnspA>;
+  using FrgTypeB = GMMA::smem_desc<tnspB>;
+
+  using Shape_MNK = Shape<_64,_56,_16>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 16>;
+  using BLayout = GMMA::ABLayout< 56, 16>;
+  using CLayout = GMMA::CLayout_64x56;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x56x16_F32F16F16_RS = SM90::GMMA::MMA_64x56x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>;
+
+template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x56x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = half_t;
+  using ValTypeB = half_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<tnspB>;
+
+  using Shape_MNK = Shape<_64,_56,_16>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x16;
+  using BLayout = GMMA::ABLayout< 56, 16>;
+  using CLayout = GMMA::CLayout_64x56;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x72x16_F32F16F16_SS = SM90::GMMA::MMA_64x72x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>;
+
+template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x72x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = half_t;
+  using ValTypeB = half_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<tnspA>;
+  using FrgTypeB = GMMA::smem_desc<tnspB>;
+
+  using Shape_MNK = Shape<_64,_72,_16>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 16>;
+  using BLayout = GMMA::ABLayout< 72, 16>;
+  using CLayout = GMMA::CLayout_64x72;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x72x16_F32F16F16_RS = SM90::GMMA::MMA_64x72x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>;
+
+template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x72x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = half_t;
+  using ValTypeB = half_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<tnspB>;
+
+  using Shape_MNK = Shape<_64,_72,_16>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x16;
+  using BLayout = GMMA::ABLayout< 72, 16>;
+  using CLayout = GMMA::CLayout_64x72;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x80x16_F32F16F16_SS = SM90::GMMA::MMA_64x80x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>;
+
+template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x80x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = half_t;
+  using ValTypeB = half_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<tnspA>;
+  using FrgTypeB = GMMA::smem_desc<tnspB>;
+
+  using Shape_MNK = Shape<_64,_80,_16>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 16>;
+  using BLayout = GMMA::ABLayout< 80, 16>;
+  using CLayout = GMMA::CLayout_64x80;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x80x16_F32F16F16_RS = SM90::GMMA::MMA_64x80x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>;
+
+template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x80x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = half_t;
+  using ValTypeB = half_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<tnspB>;
+
+  using Shape_MNK = Shape<_64,_80,_16>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x16;
+  using BLayout = GMMA::ABLayout< 80, 16>;
+  using CLayout = GMMA::CLayout_64x80;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x88x16_F32F16F16_SS = SM90::GMMA::MMA_64x88x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>;
+
+template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x88x16_F32F16F16_SS<tnspA, tnspB, scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = half_t;
+  using ValTypeB = half_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<tnspA>;
+  using FrgTypeB = GMMA::smem_desc<tnspB>;
+
+  using Shape_MNK = Shape<_64,_88,_16>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 16>;
+  using BLayout = GMMA::ABLayout< 88, 16>;
+  using CLayout = GMMA::CLayout_64x88;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::Major tnspA,
+  GMMA::Major tnspB,
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x88x16_F32F16F16_RS = SM90::GMMA::MMA_64x88x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>;
+
+template <GMMA::Major tnspA, GMMA::Major tnspB, GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x88x16_F32F16F16_RS<tnspA, tnspB, scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = half_t;
+  using ValTypeB = half_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<tnspB>;
+
+  using Shape_MNK = Shape<_64,_88,_16>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x16;
+  using BLayout = GMMA::ABLayout< 88, 16>;
+  using CLayout = GMMA::CLayout_64x88;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x16_F32F16F16_SS = SM90::GMMA::MMA_64x104x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<104, 16>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x16_F32F16F16_RS = SM90::GMMA::MMA_64x104x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<104, 16>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F32F16F16_SS = SM90::GMMA::MMA_64x112x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F32F16F16_RS = SM90::GMMA::MMA_64x112x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x16_F32F16F16_SS = SM90::GMMA::MMA_64x120x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<120, 16>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x16_F32F16F16_RS = SM90::GMMA::MMA_64x120x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<120, 16>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x16_F32F16F16_SS = SM90::GMMA::MMA_64x136x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x16_F32F16F16_RS = SM90::GMMA::MMA_64x136x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F32F16F16_SS = SM90::GMMA::MMA_64x144x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major 
tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F32F16F16_RS = SM90::GMMA::MMA_64x144x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x16_F32F16F16_SS = SM90::GMMA::MMA_64x152x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x16_F32F16F16_RS = SM90::GMMA::MMA_64x152x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F32F16F16_SS = SM90::GMMA::MMA_64x160x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F32F16F16_RS = SM90::GMMA::MMA_64x160x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + 
using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x16_F32F16F16_SS = SM90::GMMA::MMA_64x168x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x16_F32F16F16_RS = SM90::GMMA::MMA_64x168x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F32F16F16_SS = SM90::GMMA::MMA_64x176x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F32F16F16_RS = SM90::GMMA::MMA_64x176x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> 
+using SM90_64x184x16_F32F16F16_SS = SM90::GMMA::MMA_64x184x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x16_F32F16F16_RS = SM90::GMMA::MMA_64x184x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x16_F32F16F16_SS = SM90::GMMA::MMA_64x200x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x16_F32F16F16_RS = SM90::GMMA::MMA_64x200x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F32F16F16_SS = SM90::GMMA::MMA_64x208x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = 
GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F32F16F16_RS = SM90::GMMA::MMA_64x208x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x16_F32F16F16_SS = SM90::GMMA::MMA_64x216x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x16_F32F16F16_RS = SM90::GMMA::MMA_64x216x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F32F16F16_SS = SM90::GMMA::MMA_64x224x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F32F16F16_RS = 
SM90::GMMA::MMA_64x224x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x16_F32F16F16_SS = SM90::GMMA::MMA_64x232x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<232, 16>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x16_F32F16F16_RS = SM90::GMMA::MMA_64x232x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<232, 16>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F32F16F16_SS = SM90::GMMA::MMA_64x240x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F32F16F16_RS = SM90::GMMA::MMA_64x240x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x16_F32F16F16_SS = SM90::GMMA::MMA_64x248x16_F32F16F16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x16_F32F16F16_RS = SM90::GMMA::MMA_64x248x16_F32F16F16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = half_t; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x24x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 24, 16>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x24x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 24, 16>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x40x16_F32BF16BF16_SS; + +template +struct 
MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 40, 16>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x40x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 40, 16>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x48x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x48x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x56x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 56, 16>; + using CLayout = 
GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x56x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 56, 16>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x72x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 72, 16>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x72x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 72, 16>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x80x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x16_F32BF16BF16_RS = 
SM90::GMMA::MMA_64x80x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x88x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout< 88, 16>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x88x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout< 88, 16>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x104x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<104, 16>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x104x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<104, 16>; + using 
CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x112x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x112x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x120x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<120, 16>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x120x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<120, 16>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x16_F32BF16BF16_SS = 
SM90::GMMA::MMA_64x136x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x136x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x144x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x144x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x152x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + 
using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x152x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x160x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x160x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x168x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One +> +using SM90_64x168x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x168x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x176x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x176x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x184x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x184x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x200x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x200x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x208x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x208x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn 
scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x216x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x216x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x224x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x224x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x232x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + 
using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<232, 16>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x232x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<232, 16>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x240x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x240x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x16_F32BF16BF16_SS = SM90::GMMA::MMA_64x248x16_F32BF16BF16_SS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::Major tnspA, + GMMA::Major tnspB, + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x16_F32BF16BF16_RS = SM90::GMMA::MMA_64x248x16_F32BF16BF16_RS; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = bfloat16_t; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x24x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 24, 8>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x24x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 24, 8>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x40x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 40, 8>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x40x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 40, 8>; + using CLayout = GMMA::CLayout_64x40; + + 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x48x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 48, 8>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x48x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 48, 8>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x56x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 56, 8>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x56x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 56, 8>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x72x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 72, 8>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x72x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 72, 8>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x80x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 80, 8>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x80x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 80, 8>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x88x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout< 88, 8>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x88x8_F32TF32TF32_RS_TN; 
+ +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout< 88, 8>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x104x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<104, 8>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x104x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<104, 8>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x112x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<112, 8>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x112x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<112, 8>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
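Every block in this hunk follows one of two fixed patterns: `_SS` atoms read both A and B through shared-memory matrix descriptors (`GMMA::smem_desc<...>` fragments), while `_RS` atoms feed A from the warpgroup's registers and keep only the B descriptor, swapping `ALayout` for the register layout (`GMMA::ALayout_64x8` here). For reference, this is the 64x112x8 TF32 `SS_TN` entry written out with the template arguments these aliases take in CuTe; the annotations are editorial, not part of the header.

```cpp
template <
  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
>
using SM90_64x112x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x112x8_F32TF32TF32_SS_TN<scaleA, scaleB>;

template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
struct MMA_Traits<SM90_64x112x8_F32TF32TF32_SS_TN<scaleA, scaleB>>
{
  using ValTypeD = float;                             // D (accumulator output) type
  using ValTypeA = tfloat32_t;                        // A operand type
  using ValTypeB = tfloat32_t;                        // B operand type
  using ValTypeC = float;                             // C (accumulator input) type

  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;   // A is consumed via an smem matrix descriptor
  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;   // B likewise (TN => both operands K-major)

  using Shape_MNK = Shape<_64,_112,_8>;               // one atom covers a 64x112x8 tile
  using ThrID     = Layout<_128>;                     // issued by one warpgroup (128 threads)
  using ALayout   = GMMA::ABLayout< 64, 8>;           // (thread,value) -> coordinate maps for the
  using BLayout   = GMMA::ABLayout<112, 8>;           //   64x8 A tile and the 112x8 B tile
  using CLayout   = GMMA::CLayout_64x112;             // accumulator distribution over the warpgroup

  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;   // runtime flag: One => D += A*B, Zero => D = A*B
};
```

The matching `_RS_TN` entry differs only in that `FrgTypeA` is dropped (A is supplied from registers) and `ALayout` becomes `GMMA::ALayout_64x8`.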
+template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x120x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<120, 8>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x120x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<120, 8>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x136x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<136, 8>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x136x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<136, 8>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x144x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = 
GMMA::ABLayout<144, 8>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x144x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<144, 8>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x152x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<152, 8>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x152x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<152, 8>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x160x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<160, 8>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x160x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + 
using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<160, 8>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x168x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<168, 8>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x168x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<168, 8>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x176x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<176, 8>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x176x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<176, 8>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using 
SM90_64x184x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x184x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<184, 8>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x184x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<184, 8>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x200x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<200, 8>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x200x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<200, 8>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x208x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<208, 8>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x208x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<208, 8>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x216x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<216, 8>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x216x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<216, 8>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x224x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<224, 8>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x224x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = 
Shape<_64,_224,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<224, 8>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x232x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<232, 8>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x232x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<232, 8>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x240x8_F32TF32TF32_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<240, 8>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x240x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<240, 8>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x8_F32TF32TF32_SS_TN = SM90::GMMA::MMA_64x248x8_F32TF32TF32_SS_TN; + +template +struct 
MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 8>; + using BLayout = GMMA::ABLayout<248, 8>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x8_F32TF32TF32_RS_TN = SM90::GMMA::MMA_64x248x8_F32TF32TF32_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = tfloat32_t; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_8>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x8; + using BLayout = GMMA::ABLayout<248, 8>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x24x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x48x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using 
FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x80x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x112x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x144x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = 
GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x160x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x176x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
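The int8 entries above (the `_SATURATE` aliases map to the saturating, `satfinite` form of the instruction) are not used directly; they are wrapped into a `TiledMMA` and handed to `cute::gemm`. A minimal host-side sketch of that composition, assuming an SM90-capable CUTLASS checkout on the include path and compilation with nvcc; note that in recent releases these extended-N shapes may additionally need the extended GMMA shape headers enabled (e.g. via `CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED`).

```cpp
#include <cute/tensor.hpp>
#include <cute/atom/mma_atom.hpp>

int main() {
  using namespace cute;

  // s32 += s8 * s8, both operands from shared memory, defined just above.
  using Atom       = SM90_64x176x32_S32S8S8_SS_TN;
  using AtomTraits = MMA_Traits<Atom>;

  // Tile two atoms along M: one CTA tile of 128x176x32, i.e. two warpgroups.
  auto tiled_mma = make_tiled_mma(Atom{}, Layout<Shape<_2,_1,_1>>{});

  CUTE_STATIC_ASSERT_V(size(tiled_mma) == Int<256>{});   // 2 x 128 threads

  print("Atom Shape_MNK : "); print(AtomTraits::Shape_MNK{}); print("\n");
  print("TiledMMA       : "); print(tiled_mma);              print("\n");
  return 0;
}
```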
+//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x208x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x224x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8S8_SS_TN = SM90::GMMA::MMA_64x240x32_S32S8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8S8_SS_TN_SATURATE = 
SM90::GMMA::MMA_64x240x32_S32S8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x24x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x48x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x80x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = 
GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x112x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x144x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x160x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + 
using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x176x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x208x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x224x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8S8_RS_TN = SM90::GMMA::MMA_64x240x32_S32S8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32S8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x24x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + 
using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x48x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x80x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x112x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = 
GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x144x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x160x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x176x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x208x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x224x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8U8_SS_TN_SATURATE = 
SM90::GMMA::MMA_64x224x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8U8_SS_TN = SM90::GMMA::MMA_64x240x32_S32S8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32S8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x24x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x48x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; 
+ using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x80x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x112x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8U8_RS_TN = 
SM90::GMMA::MMA_64x144x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x160x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x176x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = 
GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x208x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x224x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8U8_RS_TN = SM90::GMMA::MMA_64x240x32_S32S8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32S8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32S8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = 
int32_t; + using ValTypeA = int8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x24x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x48x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x80x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using 
ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x112x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x144x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; 
+ +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x160x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x176x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x208x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32U8S8_SS_TN_SATURATE = 
SM90::GMMA::MMA_64x208x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x224x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32U8S8_SS_TN = SM90::GMMA::MMA_64x240x32_S32U8S8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32U8S8_SS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32U8S8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x24x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = 
int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x48x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x80x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x112x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x144x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x160x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; 
+ + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x176x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x208x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x224x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using 
SM90_64x224x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32U8S8_RS_TN = SM90::GMMA::MMA_64x240x32_S32U8S8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32U8S8_RS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32U8S8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x24x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x48x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + 
using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x80x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x112x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = 
GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x144x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x160x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x176x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// 
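The trait specializations in this hunk are consumed entirely at compile time. Below is a minimal, illustrative sketch (not part of the patch) of how one of the ops declared above can be inspected through `cute::MMA_Traits`; the include path and the placement of these names in the `cute` namespace are assumptions about a typical CUTLASS checkout.

```cpp
// Illustrative sketch only, not part of this patch.
// Assumes a CUTLASS checkout with CuTe on the include path; header path is assumed.
#include <cstdint>
#include <type_traits>
#include <cute/atom/mma_traits_sm90_gmma.hpp>

// One of the ops declared above: D(s32) += A(u8) * B(u8), tile 64x176x32, A and B in smem (SS).
using Op     = cute::SM90_64x176x32_S32U8U8_SS_TN;
using Traits = cute::MMA_Traits<Op>;

static_assert(std::is_same_v<Traits::ValTypeA, std::uint8_t>, "A operand is u8");
static_assert(std::is_same_v<Traits::ValTypeB, std::uint8_t>, "B operand is u8");
static_assert(std::is_same_v<Traits::ValTypeD, std::int32_t>, "accumulator is s32");
```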
+ + +using SM90_64x176x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x208x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x224x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x224x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32U8U8_SS_TN = SM90::GMMA::MMA_64x240x32_S32U8U8_SS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA 
= uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x240x32_S32U8U8_SS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32U8U8_SS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x24x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x24x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x24x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x48x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x48x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x48x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + 
+ GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x80x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x80x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x80x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x112x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x112x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x112x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x144x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x144x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x144x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using 
ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x160x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x160x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x160x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x176x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x176x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x176x32_S32U8U8_RS_TN_SATURATE; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +using SM90_64x208x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x208x32_S32U8U8_RS_TN; + +template <> +struct MMA_Traits +{ + using ValTypeD = int32_t; + using ValTypeA = uint8_t; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+using SM90_64x208x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x208x32_S32U8U8_RS_TN_SATURATE;
+
+template <>
+struct MMA_Traits<SM90_64x208x32_S32U8U8_RS_TN_SATURATE>
+{
+  using ValTypeD = int32_t;
+  using ValTypeA = uint8_t;
+  using ValTypeB = uint8_t;
+  using ValTypeC = int32_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_208,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<208, 32>;
+  using CLayout = GMMA::CLayout_64x208;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+using SM90_64x224x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x224x32_S32U8U8_RS_TN;
+
+template <>
+struct MMA_Traits<SM90_64x224x32_S32U8U8_RS_TN>
+{
+  using ValTypeD = int32_t;
+  using ValTypeA = uint8_t;
+  using ValTypeB = uint8_t;
+  using ValTypeC = int32_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_224,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<224, 32>;
+  using CLayout = GMMA::CLayout_64x224;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+using SM90_64x224x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x224x32_S32U8U8_RS_TN_SATURATE;
+
+template <>
+struct MMA_Traits<SM90_64x224x32_S32U8U8_RS_TN_SATURATE>
+{
+  using ValTypeD = int32_t;
+  using ValTypeA = uint8_t;
+  using ValTypeB = uint8_t;
+  using ValTypeC = int32_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_224,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<224, 32>;
+  using CLayout = GMMA::CLayout_64x224;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+using SM90_64x240x32_S32U8U8_RS_TN = SM90::GMMA::MMA_64x240x32_S32U8U8_RS_TN;
+
+template <>
+struct MMA_Traits<SM90_64x240x32_S32U8U8_RS_TN>
+{
+  using ValTypeD = int32_t;
+  using ValTypeA = uint8_t;
+  using ValTypeB = uint8_t;
+  using ValTypeC = int32_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_240,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<240, 32>;
+  using CLayout = GMMA::CLayout_64x240;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+using SM90_64x240x32_S32U8U8_RS_TN_SATURATE = SM90::GMMA::MMA_64x240x32_S32U8U8_RS_TN_SATURATE;
+
+template <>
+struct MMA_Traits<SM90_64x240x32_S32U8U8_RS_TN_SATURATE>
+{
+  using ValTypeD = int32_t;
+  using ValTypeA = uint8_t;
+  using ValTypeB = uint8_t;
+  using ValTypeC = int32_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_240,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<240, 32>;
+  using CLayout = GMMA::CLayout_64x240;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x24x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x24x32_F16E4M3E4M3_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x24x32_F16E4M3E4M3_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e4m3_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_24,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout< 24, 32>;
+  using CLayout = GMMA::CLayout_64x24;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x24x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x24x32_F16E4M3E4M3_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x24x32_F16E4M3E4M3_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e4m3_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_24,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout< 24, 32>;
+  using CLayout = GMMA::CLayout_64x24;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x24x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x24x32_F32E4M3E4M3_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x24x32_F32E4M3E4M3_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e4m3_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_24,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout< 24, 32>;
+  using CLayout = GMMA::CLayout_64x24;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x24x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x24x32_F32E4M3E4M3_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x24x32_F32E4M3E4M3_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e4m3_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_24,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout< 24, 32>;
+  using CLayout = GMMA::CLayout_64x24;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x40x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x40x32_F16E4M3E4M3_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x40x32_F16E4M3E4M3_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e4m3_t;
+  using ValTypeB = float_e4m3_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_40,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout< 40, 32>;
+  using CLayout = GMMA::CLayout_64x40;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
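The FP8 specializations above differ from the integer ones earlier in the hunk in that both the op alias and its `MMA_Traits` specialization are templated on two `GMMA::ScaleIn` parameters, each defaulting to `ScaleIn::One`. A small illustrative sketch (not part of the patch; include path and namespace placement assumed) of instantiating one such op with explicit scale arguments:

```cpp
// Illustrative sketch only, not part of this patch; include path is assumed.
#include <type_traits>
#include <cute/atom/mma_traits_sm90_gmma.hpp>

// Spell out both ScaleIn parameters explicitly (they default to ScaleIn::One).
using Fp8Op     = cute::SM90_64x24x32_F32E4M3E4M3_SS_TN<cute::GMMA::ScaleIn::One,
                                                        cute::GMMA::ScaleIn::One>;
using Fp8Traits = cute::MMA_Traits<Fp8Op>;

static_assert(std::is_same_v<Fp8Traits::ValTypeA, Fp8Traits::ValTypeB>, "A and B are both e4m3");
static_assert(std::is_same_v<Fp8Traits::ValTypeD, float>, "F32 accumulator");
```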
+template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x40x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x40x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x40x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x48x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x48x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = 
GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x48x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x48x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x56x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x56x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x56x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using 
ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x56x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x72x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x72x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x72x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x72x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x80x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x80x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x80x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x80x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = 
GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x88x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x88x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x88x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x88x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x104x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = 
half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x104x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x104x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x104x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x112x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x112x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x112x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x112x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x120x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x120x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = 
GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x120x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x120x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x136x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x136x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x136x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + 
using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x136x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x144x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x144x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x144x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x144x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x152x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x152x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x152x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x152x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + 
using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x160x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x160x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x160x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x160x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x168x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = 
float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x168x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x168x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x168x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x176x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn 
scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x176x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x176x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x176x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x184x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x184x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<184, 
32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x184x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x184x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x200x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x200x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x200x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using 
ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x200x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x208x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x208x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x208x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x208x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x216x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x216x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x216x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x216x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = 
GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x224x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x224x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x224x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x224x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x232x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = 
float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x232x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x232x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x232x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x240x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
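//
// A minimal usage sketch (an assumption based on CuTe's MMA_Atom/TiledMMA conventions,
// not something this file spells out): trait specializations like the ones above are
// consumed by wrapping the corresponding op in a TiledMMA. The atom picked below
// (SM90_64x232x32_F32E4M3E4M3_SS_TN, defined earlier in this file) and the function and
// tensor names (fp8_gmma_usage_sketch, tCrA, tCrB, tCrC) are illustrative only.
//
#include <cute/tensor.hpp>
#include <cute/atom/mma_atom.hpp>

__device__ void fp8_gmma_usage_sketch()
{
  using namespace cute;

  // Both scale inputs default to GMMA::ScaleIn::One via the alias's default arguments.
  using MmaOp = SM90_64x232x32_F32E4M3E4M3_SS_TN<>;

  // One warpgroup-wide atom, tiled with the default 1x1x1 atom layout.
  auto tiled_mma = make_tiled_mma(MmaOp{});

  // Each of the 128 threads (ThrID = Layout<_128>) takes its slice of the atom; a
  // mainloop would then combine A/B shared-memory descriptors with the f32 accumulator
  // via cute::gemm(tiled_mma, tCrA, tCrB, tCrC).
  auto thr_mma = tiled_mma.get_thread_slice(threadIdx.x);
  (void) thr_mma;
}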
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x240x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x240x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x240x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F16E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x248x32_F16E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F16E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x248x32_F16E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using 
Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F32E4M3E4M3_SS_TN = SM90::GMMA::MMA_64x248x32_F32E4M3E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F32E4M3E4M3_RS_TN = SM90::GMMA::MMA_64x248x32_F32E4M3E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x24x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x24x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F32E4M3E5M2_SS_TN = 
SM90::GMMA::MMA_64x24x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x24x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x40x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x40x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x40x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; 
+}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x40x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x48x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x48x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x48x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x48x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = 
Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x56x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x56x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x56x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x56x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E4M3E5M2_SS_TN = 
SM90::GMMA::MMA_64x72x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x72x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x72x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x72x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x80x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x80x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x80x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x80x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x88x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x88x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using 
Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x88x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x88x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x104x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x104x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E4M3E5M2_SS_TN = 
SM90::GMMA::MMA_64x104x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x104x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x112x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x112x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x112x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x112x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x120x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x120x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x120x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x120x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x136x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x136x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x136x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x136x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using 
SM90_64x144x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x144x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x144x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x144x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x144x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x152x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = 
GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x152x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x152x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x152x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x160x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x160x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + 
using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x160x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x160x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x168x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x168x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One +> +using SM90_64x168x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x168x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x168x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x176x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x176x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x176x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using 
CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x176x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x184x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x184x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x184x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x184x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = 
float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x200x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x200x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x200x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x200x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x208x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x208x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x208x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x208x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x216x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = 
GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x216x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x216x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x216x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x224x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x224x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = 
float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x224x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x224x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x232x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x232x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA 
= GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x232x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x232x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x240x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x240x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x240x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using 
BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x240x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F16E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x248x32_F16E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F16E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x248x32_F16E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F32E4M3E5M2_SS_TN = SM90::GMMA::MMA_64x248x32_F32E4M3E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F32E4M3E5M2_RS_TN = SM90::GMMA::MMA_64x248x32_F32E4M3E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using 
ValTypeA = float_e4m3_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x24x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x24x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x24x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x24x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x40x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x40x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x40x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x40x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x48x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = 
GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x48x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x48x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x48x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x56x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x56x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using 
ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x56x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x56x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x72x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x72x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn 
scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x72x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x72x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x80x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x80x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x80x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = 
GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x80x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x88x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x88x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x88x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x88x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = 
float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x104x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x104x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x104x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x104x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> 
+using SM90_64x112x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x112x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x112x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x112x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x112x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x120x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = 
GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x120x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x120x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x120x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x136x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x136x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + 
using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x136x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x136x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x144x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x144x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = 
GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x144x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x144x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x152x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x152x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x152x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 32>; + using 
CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x152x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x160x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x160x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x160x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x160x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = 
float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x168x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x168x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x168x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x168x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + 
GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x176x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x176x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x176x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x176x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x184x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = 
GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x184x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x184x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x184x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x200x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x200x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = 
float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x200x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x200x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x208x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x208x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA 
= GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x208x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x208x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x208x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x216x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x216x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x216x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using 
BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x216x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x224x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x224x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x224x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x224x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using 
ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x232x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x232x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x232x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x232x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x240x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x240x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x240x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x240x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F16E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x248x32_F16E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F16E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x248x32_F16E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F32E5M2E4M3_SS_TN = SM90::GMMA::MMA_64x248x32_F32E5M2E4M3_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F32E5M2E4M3_RS_TN = SM90::GMMA::MMA_64x248x32_F32E5M2E4M3_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x24x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x24x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using 
ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x24x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x24x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x24x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x40x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x40x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x40x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x40x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x40x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x48x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x48x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x48x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + 
using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x48x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x48x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x56x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x56x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x56x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x56x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x56x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = 
float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x72x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x72x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x72x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x72x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x72x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x80x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x80x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x80x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x80x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x80x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x88x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = 
GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x88x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x88x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x88x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x88x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x104x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x104x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + 
using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x104x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x104x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x104x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x112x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x112x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = 
GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x112x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x112x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x112x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x120x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x120x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x120x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using 
BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x120x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x120x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x136x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x136x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x136x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x136x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x136x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using 
ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x144x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x144x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x144x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x144x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x144x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + 
GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x152x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x152x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x152x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x152x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x152x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x160x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x160x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x160x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x160x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x160x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x168x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x168x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using 
ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x168x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x168x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x168x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x176x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x176x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// 
+ +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x176x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x176x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x176x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x184x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x184x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x184x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using 
ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x184x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x184x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x200x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x200x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x200x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x200x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x200x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> 
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_200,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<200, 32>;
+  using CLayout = GMMA::CLayout_64x200;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x208x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x208x32_F16E5M2E5M2_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x208x32_F16E5M2E5M2_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_208,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<208, 32>;
+  using CLayout = GMMA::CLayout_64x208;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x208x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x208x32_F16E5M2E5M2_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x208x32_F16E5M2E5M2_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = half_t;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = half_t;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_208,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<208, 32>;
+  using CLayout = GMMA::CLayout_64x208;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x208x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x208x32_F32E5M2E5M2_SS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x208x32_F32E5M2E5M2_SS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = float;
+
+  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_208,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ABLayout< 64, 32>;
+  using BLayout = GMMA::ABLayout<208, 32>;
+  using CLayout = GMMA::CLayout_64x208;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <
+  GMMA::ScaleIn scaleA = GMMA::ScaleIn::One,
+  GMMA::ScaleIn scaleB = GMMA::ScaleIn::One
+>
+using SM90_64x208x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x208x32_F32E5M2E5M2_RS_TN<scaleA, scaleB>;
+
+template <GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
+struct MMA_Traits<SM90_64x208x32_F32E5M2E5M2_RS_TN<scaleA, scaleB>>
+{
+  using ValTypeD = float;
+  using ValTypeA = float_e5m2_t;
+  using ValTypeB = float_e5m2_t;
+  using ValTypeC = float;
+
+  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;
+
+  using Shape_MNK = Shape<_64,_208,_32>;
+  using ThrID   = Layout<_128>;
+  using ALayout = GMMA::ALayout_64x32;
+  using BLayout = GMMA::ABLayout<208, 32>;
+  using CLayout = GMMA::CLayout_64x208;
+
+  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
+};
+
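+// NOTE: The E5M2 x E5M2 TN traits above and below all follow one pattern: a 64xNx32 GMMA with
+// K-major A and B operands and either half_t or float accumulators. A minimal usage sketch,
+// assuming an SM90 build; the partitioned tensor names are illustrative placeholders, not
+// defined in this header:
+//
+//   TiledMMA tiled_mma = make_tiled_mma(SM90_64x208x32_F32E5M2E5M2_SS_TN<>{});  // scales default to One
+//   // Partition A/B/C with tiled_mma, then call cute::gemm(tiled_mma, tCrA, tCrB, tCrC);
+//   // accumulate_ (GMMA::ScaleOut::One vs. Zero) selects accumulation into C vs. overwrite.
+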
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x216x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x216x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x216x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x216x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x216x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x224x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x224x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x224x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x224x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x224x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x232x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using 
SM90_64x232x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x232x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x232x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x232x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x232x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x240x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x240x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x240x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x240x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x240x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F16E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x248x32_F16E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F16E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x248x32_F16E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F32E5M2E5M2_SS_TN = SM90::GMMA::MMA_64x248x32_F32E5M2E5M2_SS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = 
GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + GMMA::ScaleIn scaleA = GMMA::ScaleIn::One, + GMMA::ScaleIn scaleB = GMMA::ScaleIn::One +> +using SM90_64x248x32_F32E5M2E5M2_RS_TN = SM90::GMMA::MMA_64x248x32_F32E5M2E5M2_RS_TN; + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = float_e5m2_t; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp b/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp new file mode 100644 index 0000000000..161dc7ecf0 --- /dev/null +++ b/include/cute/atom/mma_traits_sm90_gmma_sparse.hpp @@ -0,0 +1,7738 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include // cute::smem_sparse_ptr_flag +#include // cute::Swizzle +#include // cute::Tensor +#include // cute::LayoutType +#include // cute::SM90::SPARSE::GMMA_64x8x32_F16F16F16_SS, etc +#include // cute::GMMA::Layout_* +#include // cute::MMA_Traits +#include // cute::ComposedLayout +#include // cute::is_static + +namespace cute { + +namespace SM90::GMMA { + +/////////////////////////////////////////// +// Common layouts for GMMA Shared Memory // +/////////////////////////////////////////// + +// M|N-major layouts in units of Type and sparsity factor S +template +using Layout_MN_INTER_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_MN_INTER_Atom{}.layout_b()))>; +template +using Layout_MN_SW32_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_MN_SW32_Atom{}.layout_b()))>; +template +using Layout_MN_SW64_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_MN_SW64_Atom{}.layout_b()))>; +template +using Layout_MN_SW128_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_MN_SW128_Atom{}.layout_b()))>; + +// K-major layouts in units of Type and sparsity factor S +template +using Layout_K_INTER_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_K_INTER_Atom{}.layout_b()))>; +template +using Layout_K_SW32_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_K_SW32_Atom{}.layout_b()))>; +template +using Layout_K_SW64_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_K_SW64_Atom{}.layout_b()))>; +template +using Layout_K_SW128_SpAtom = ComposedLayout, smem_sparse_ptr_flag_bits>, + decltype(blocked_product(Layout>>{}, Layout_K_SW128_Atom{}.layout_b()))>; + +// With GMMA::Major param +template +using Layout_INTER_SpAtom = typename conditional, + Layout_K_INTER_SpAtom>::type; +template +using Layout_SW32_SpAtom = typename conditional, + Layout_K_SW32_SpAtom>::type; +template +using Layout_SW64_SpAtom = typename conditional, + Layout_K_SW64_SpAtom>::type; +template +using Layout_SW128_SpAtom = typename conditional, + Layout_K_SW128_SpAtom>::type; + +/////////////////////////////////////////////////////////////////////////////// +// Higher level GMMA Descriptor utilities +/////////////////////////////////////////////////////////////////////////////// + +template +struct sparse_smem_desc : DescriptorIterator {}; + +} // end namespace SM90::GMMA + +// Customization point for creating a cute::GMMAsparse_smem_desc Tensor +template +struct MakeTensor> +{ + // Note that this is the exact same as cute::GMMAsmem_desc above, plus additional static checks. 
+ template + CUTE_HOST_DEVICE constexpr auto + operator()(Tensor const& smem_tensor) + { + static_assert(is_smem::value, "Expected SMEM Tensor to construct a GMMA Desc Tensor"); + static_assert(is_sparse::value, "Expected sparse value_type."); + static_assert(is_sparse_ptr::value, "Expected sparse iter."); + return make_tensor(SM90::GMMA::DescriptorIterator{SM90::GMMA::make_gmma_desc(tensor<0>(smem_tensor))}, + replace<0>(recast(smem_tensor).layout(), Layout<_1,_0>{})); + } +}; + +/////////////////////////////////////////////////////////////////////////////// +//////////////////////////// MMA_TRAITS /////////////////////////////////////// +/////////////////////////////////////////////////////////////////////////////// + +namespace SM90::GMMA { + +// Metadata layouts +using ELayout_64x64 = Layout, Shape <_32>>, + Stride, Stride<_64>>>; + +using ELayout_64x32 = Layout, Shape <_16,_2>>, + Stride, Stride<_64,_8>>>; + +using ELayout_64x16 = Layout, Shape < _8,_2>>, + Stride, Stride<_64,_8>>>; + +} // namespace SM90::GMMA + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace SM90::GMMA::SPARSE { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTE_HOST_DEVICE constexpr void +mma_unpack(MMA_Traits const& traits, + Tensor & D, + Tensor const& A_zipped, + Tensor const& B, + Tensor const& C) +{ + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + static_assert(is_rmem_v, "Expected registers in MMA_Atom::call"); + + using DRegisters = typename MMAOp::DRegisters; + using ARegisters = typename MMAOp::ARegisters; + using ERegisters = typename MMAOp::ERegisters; + using BRegisters = typename MMAOp::BRegisters; + using CRegisters = typename MMAOp::CRegisters; + + // Register value types from the MMAOp register arrays + using RegTypeD = typename remove_extent::type; + using RegTypeA = typename remove_extent::type; + using RegTypeE = typename remove_extent::type; + using RegTypeB = typename remove_extent::type; + using RegTypeC = typename remove_extent::type; + + constexpr int RegNumA = extent::value; + constexpr int RegNumE = extent::value; + constexpr int RegNumB = extent::value; + constexpr int RegNumC = extent::value; + + auto [A, E] = unzip_tensor(A_zipped); + Tensor rA = recast(A); + Tensor rE = recast(E); + Tensor rB = recast(B); + + CUTE_STATIC_ASSERT_V(size(rA) == Int{}); + CUTE_STATIC_ASSERT_V(size(rE) == Int{}); + CUTE_STATIC_ASSERT_V(size(rB) == Int{}); + + static_assert(is_same::value, "GMMA DRegisters must have void type."); + static_assert(is_same::value, "GMMA C and D value_type must match."); + static_assert(is_same::value, "GMMA C and D layouts must match."); + + Tensor rC = recast(D); // NOTE: D and C are same, so use mutable D + + CUTE_STATIC_ASSERT_V(size(rC) == Int{}); + + detail::explode(MMAOp::fma, + rA, make_int_sequence{}, + rB, make_int_sequence{}, + rC, make_int_sequence{}, + rE, make_int_sequence{}, + &(traits.accumulate_), seq<0>{}); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace SM90::SPARSE + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = 
sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = 
GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
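+// NOTE: In these sparse traits, ValTypeA = sparse_elem<2, half_t> marks A as structured-sparse
+// (sparsity factor 2, i.e. the 2:4 pattern: two logical K elements per stored value), and
+// ValTypeE = sparse_elem<8, uint8_t> carries the sparsity metadata whose per-thread placement is
+// given by ELayout. The mma_unpack overload above splits the zipped A fragment before issuing the
+// instruction:
+//
+//   auto [A, E] = unzip_tensor(A_zipped);  // values and metadata travel together until unpack
+//
+// so callers pass a single zipped A tensor rather than separate value and metadata tensors.
+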
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using 
ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = 
GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = 
sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 8, 32>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 16, 32>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 32, 32>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
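+// NOTE: The BF16 sparse traits mirror the F16 ones above (same 64xNx32 shapes, ELayout_64x32
+// metadata layout, sparse_elem<2, ...> A operand); the BF16 atoms listed here pair only with
+// float accumulators.
+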
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 64, 32>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 96, 32>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE 
= sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<128, 32>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<192, 32>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<256, 32>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 8, 16>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 16, 16>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 32, 16>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
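Every MMA_Traits specialization in this header follows the same skeleton: it binds one sparse wgmma instruction shape to its operand value types (A carried as a 2:4-sparse sparse_elem, E as the packed metadata), the 128-thread warpgroup ThrID, and the A/E/B/C layouts the instruction expects. For orientation, one representative specialization is written out in full below; the op name (SM90::SPARSE::GMMA_64x64x32_F32BF16BF16_SS) and the exact template-parameter spelling are assumptions for illustration, not text taken from this diff.

  template <GMMA::Major tnspA, GMMA::Major tnspB,
            GMMA::ScaleIn scaleA, GMMA::ScaleIn scaleB>
  struct MMA_Traits<SM90::SPARSE::GMMA_64x64x32_F32BF16BF16_SS<tnspA, tnspB, scaleA, scaleB>>
  {
    using ValTypeD = float;
    using ValTypeA = sparse_elem<2, bfloat16_t>;  // A operand, 2:4 structured sparse
    using ValTypeE = sparse_elem<8, uint8_t>;     // packed sparsity metadata
    using ValTypeB = bfloat16_t;
    using ValTypeC = float;

    using FrgTypeA = GMMA::smem_desc<tnspA>;      // SS: both operands described in shared memory
    using FrgTypeB = GMMA::smem_desc<tnspB>;

    using Shape_MNK = Shape<_64,_64,_32>;
    using ThrID     = Layout<_128>;               // one warpgroup
    using ALayout   = GMMA::ABLayout< 64, 32>;
    using ELayout   = GMMA::ELayout_64x32;
    using BLayout   = GMMA::ABLayout< 64, 32>;
    using CLayout   = GMMA::CLayout_64x64;

    GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
  };

The RS variants differ only in dropping FrgTypeA (A is fed from registers, so ALayout switches to the corresponding GMMA::ALayout_* register layout), and each larger N simply widens BLayout and CLayout accordingly.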
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 64, 16>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 96, 16>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE 
= sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<128, 16>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<192, 16>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<256, 16>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 
64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + 
using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = 
int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
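The sparse_elem<2, ...> A type in all of these traits encodes the 2:4 structured-sparsity contract: along the K extent of Shape_MNK only half of the logical values are physically stored, and the E operand records which positions they occupy. A small, purely illustrative helper (hypothetical, not part of this header) makes the bookkeeping concrete:

  // Hypothetical helper, for illustration only: stored vs. logical element counts
  // of an M x K sparse A block under the 2-of-4 pattern these atoms assume.
  constexpr int logical_elems(int M, int K) { return M * K; }
  constexpr int stored_elems (int M, int K) { return M * (K / 2); }  // 2 of every 4 kept

  static_assert(stored_elems(64, 64) == 64 * 32,
                "a 64x64 logical int8 A block stores 64x32 values plus metadata");

So for the K=64 int8 atoms above, a 64x64 logical A block carries 64x32 stored bytes; the rest of the information lives in the metadata tensor typed by ValTypeE.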
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = 
Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
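These traits are what let CuTe wrap a sparse GMMA instruction in the usual MMA_Atom / TiledMMA machinery. A minimal sketch, assuming a sparse op named SM90::SPARSE::GMMA_64x128x64_S32S8S8_SS_TN (the name is an assumption for illustration; use whichever struct this header actually specializes MMA_Traits for):

  #include <cute/atom/mma_atom.hpp>   // cute::MMA_Atom, cute::make_tiled_mma

  using namespace cute;

  // Hypothetical op name; substitute the real sparse GMMA struct from this header.
  using SparseOp   = SM90::SPARSE::GMMA_64x128x64_S32S8S8_SS_TN;
  using SparseAtom = MMA_Atom<SparseOp>;

  auto make_sparse_tiled_mma() {
    // One warpgroup per atom; no additional tiling of atoms across M/N/K.
    return make_tiled_mma(SparseAtom{}, Layout<Shape<_1,_1,_1>>{});
  }

From there the TiledMMA is partitioned and driven much like its dense counterpart, with the extra E (metadata) operand plumbed through by the sparse collective mainloop.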
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + 
using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = 
GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = 
GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, 
int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = 
GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = 
sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using 
Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = 
GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
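Reading these specializations, `ValTypeA = sparse_elem<2, ...>` marks the A operand as 2:4 structured-sparse (each stored element stands for two logical K elements) and `ValTypeE = sparse_elem<8, uint8_t>` packs the selection metadata at one byte per eight logical K elements. That is my reading of the `sparse_elem` ratios, not something stated in the patch; the toy calculation below (illustrative only, not part of the patch) just spells out what those ratios imply for the 64x*x64 atoms above.

```cpp
// Illustrative arithmetic only -- not part of the patch. Assumes sparse_elem<S, T>
// means one stored T per S logical elements, as suggested by the traits above.
#include <cstdio>

int main() {
  constexpr int K            = 64;  // logical K extent of the atom (from Shape_MNK)
  constexpr int kSparsityA   = 2;   // from ValTypeA = sparse_elem<2, ...>
  constexpr int kElemsPerMdB = 8;   // from ValTypeE = sparse_elem<8, uint8_t>

  constexpr int stored_A_per_row   = K / kSparsityA;    // 32 stored A values per row
  constexpr int metadata_B_per_row = K / kElemsPerMdB;  // 8 metadata bytes per row

  std::printf("stored A elems per row: %d, metadata bytes per row: %d\n",
              stored_A_per_row, metadata_B_per_row);
  return 0;
}
```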
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = 
Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using 
FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = 
GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + 
using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + 
using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + 
using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using 
BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
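These traits only advertise static metadata (value types, `Shape_MNK`, `ThrID`, and the A/B/C/E layouts); they are consumed through CuTe's generic `MMA_Atom` / `TiledMMA` machinery rather than used directly. The sketch below is illustrative only and not part of the patch: `SparseGmmaOp` is a hypothetical placeholder for one of the SM90 sparse GMMA ops whose traits are specialized above, with the concrete op type coming from the corresponding arch-level header.

```cpp
// Illustrative sketch only -- not part of the patch. "SparseGmmaOp" is a
// placeholder for one of the SM90 sparse GMMA ops whose MMA_Traits are
// specialized above; build with the CUTLASS/CuTe include path.
#include <cute/tensor.hpp>
#include <cute/atom/mma_atom.hpp>

template <class SparseGmmaOp>
void inspect_sparse_mma_traits()
{
  using namespace cute;

  using Traits = MMA_Traits<SparseGmmaOp>;

  // Static metadata advertised by the specializations above.
  using ShapeMNK = typename Traits::Shape_MNK;  // e.g. Shape<_64,_64,_64>
  using ThrID    = typename Traits::ThrID;      // one warpgroup: Layout<_128>

  // Wrap the op in an MMA_Atom and tile it the same way as a dense GMMA op.
  auto tiled_mma = make_tiled_mma(MMA_Atom<SparseGmmaOp>{},
                                  Layout<Shape<_1,_1,_1>>{});

  print(ShapeMNK{});      print("\n");
  print(ThrID{});         print("\n");
  print(size(tiled_mma)); print("\n");  // number of participating threads
}
```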
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = 
sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using 
ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, 
float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; 
+ using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, 
float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + 
using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using 
ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using 
Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = 
GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; 
+ using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using 
Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = 
GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD 
= half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_8,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 8, 64>; + using CLayout = GMMA::CLayout_64x8; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + 
using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_16,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 16, 64>; + using CLayout = GMMA::CLayout_64x16; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_32,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 32, 64>; + using CLayout = GMMA::CLayout_64x32; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using 
ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_64,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 64, 64>; + using CLayout = GMMA::CLayout_64x64; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_96,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 96, 64>; + using CLayout = GMMA::CLayout_64x96; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = 
GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_128,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<128, 64>; + using CLayout = GMMA::CLayout_64x128; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_192,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<192, 64>; + using CLayout = GMMA::CLayout_64x192; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
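// ------------------------------------------------------------------------------------------------
// Explanatory sketch (editor's note, not part of the generated traits): every specialization in
// this header follows the same recipe for binding one SM90 sparse GMMA instruction to the
// metadata that CuTe's MMA_Atom machinery expects.
//   * ValTypeA = sparse_elem<2, T>       -- A is 2:4 structured-sparse, so each stored element of
//                                           A stands for two logical K-elements.
//   * ValTypeE = sparse_elem<8, uint8_t> -- the sparsity metadata, one byte covering eight
//                                           logical elements of A.
//   * FrgTypeA / FrgTypeB = GMMA::smem_desc -- the operand is consumed through a shared-memory
//     descriptor ("SS" flavor); specializations that omit FrgTypeA source A from registers ("RS").
//   * Shape_MNK and ThrID = Layout<_128> -- the instruction tile and the single 128-thread
//                                           warpgroup that executes it.
//   * ALayout / ELayout / BLayout / CLayout -- thread/value-to-tile mappings used when
//                                              partitioning tensors over the tile.
//   * accumulate_ = GMMA::ScaleOut::One  -- accumulate into the C/D accumulator;
//                                           GMMA::ScaleOut::Zero overwrites it instead.
//
// Minimal, hypothetical usage sketch: `SomeSparseGmmaOp` below is a stand-in for any op type
// specialized in this header (the concrete op names are defined in the arch headers and are not
// reproduced here).
//
//   using TiledMma = decltype(make_tiled_mma(MMA_Atom<SomeSparseGmmaOp>{}));
//   // Partition the compressed A, the metadata E, B, and the accumulators with TiledMma, then
//   // drive the mainloop with cute::gemm(tiled_mma, ...) just as for the dense GMMA atoms.
// ------------------------------------------------------------------------------------------------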
+template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_256,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<256, 64>; + using CLayout = GMMA::CLayout_64x256; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute + +#if defined(CUTE_SM90_EXTENDED_MMA_SHAPES_ENABLED) +#include "mma_traits_sm90_gmma_sparse_ext.hpp" +#endif diff --git a/include/cute/atom/mma_traits_sm90_gmma_sparse_ext.hpp b/include/cute/atom/mma_traits_sm90_gmma_sparse_ext.hpp new file mode 100644 index 0000000000..3680b7e13f --- /dev/null +++ b/include/cute/atom/mma_traits_sm90_gmma_sparse_ext.hpp @@ -0,0 +1,17335 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include +#include + +namespace cute { + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using 
ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = 
Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
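// A reading of the recurring structure of these sparse MMA_Traits specializations
// (hedged; the exact GMMA op tags that parameterize MMA_Traits<> and the template
// parameter lists are defined in the cute/arch headers and are not reproduced here):
//  - ValTypeA = sparse_elem<2, T> marks operand A as structured-sparse, presumably
//    2:4: each stored value stands for two logical K elements.
//  - ValTypeE = sparse_elem<8, uint8_t> is the sparsity metadata; one byte covers
//    eight logical K elements, and ELayout (e.g. GMMA::ELayout_64x32) describes how
//    that metadata is distributed across the warpgroup.
//  - Shape_MNK is the logical (dense) MxNxK tile of one GMMA instruction, and
//    ThrID = Layout<_128> reflects that a full 128-thread warpgroup issues it.
//  - FrgTypeA / FrgTypeB = GMMA::smem_desc indicate operands consumed directly from
//    shared memory through descriptors; accumulate_ selects whether the instruction
//    accumulates into C (GMMA::ScaleOut::One) or overwrites it (GMMA::ScaleOut::Zero).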
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, 
uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = 
GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct 
MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = 
Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using 
ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = 
GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using 
ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + 
using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using 
ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = 
GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using 
ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout 
= GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, half_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = half_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 24, 32>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 40, 32>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = 
bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 48, 32>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 56, 32>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + 
using BLayout = GMMA::ABLayout< 72, 32>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 80, 32>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout< 88, 32>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<104, 32>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<112, 32>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<120, 32>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + 
using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<136, 32>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<144, 32>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using 
ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<152, 32>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<160, 32>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<168, 32>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<176, 32>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<184, 32>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<200, 32>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + 
using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<208, 32>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<216, 32>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using 
ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<224, 32>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<232, 32>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<240, 32>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 32>; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, bfloat16_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = bfloat16_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_32>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x32; + using ELayout = GMMA::ELayout_64x32; + using BLayout = GMMA::ABLayout<248, 32>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 24, 16>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 24, 16>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 40, 16>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 40, 16>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using 
ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 48, 16>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 56, 16>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 56, 16>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 72, 16>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = 
GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 72, 16>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 80, 16>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 88, 16>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout< 88, 16>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<104, 16>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<104, 16>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<112, 16>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<120, 16>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<120, 16>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + 
using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<136, 16>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<144, 16>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using 
ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<152, 16>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<160, 16>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<168, 16>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<176, 16>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<184, 16>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<200, 16>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + 
using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<208, 16>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<216, 16>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using 
ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<224, 16>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<232, 16>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<232, 16>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<240, 16>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 16>; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, tfloat32_t>; + using ValTypeE = sparse_elem<4, uint8_t>; + using ValTypeB = tfloat32_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_16>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x16; + using ELayout = GMMA::ELayout_64x16; + using BLayout = GMMA::ABLayout<248, 16>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, 
int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB 
= GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = 
GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; 
+}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using 
Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using 
Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; 
+ +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = 
int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + 
+ using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using 
ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using 
Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, int8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + 
using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout 
= GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD 
= int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; 
+ + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + 
using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = int8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = 
uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + 
using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 
64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = 
uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using 
FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = int32_t; + using ValTypeA = sparse_elem<2, uint8_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = uint8_t; + using ValTypeC = int32_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, 
float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using 
ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
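The sparse FP8 `MMA_Traits` entries in this hunk (and in the hunks that follow) lost their template headers and the `MMA_Traits<...>` specialization arguments during extraction, which is why every entry now reads `template struct MMA_Traits> { ... };`. As a reading aid, below is a minimal sketch of what one complete pair of entries of this shape looks like, reconstructed only from the member aliases that did survive (`Shape_MNK`, the `ValType*` aliases, and the `*Layout` aliases). The qualified op names (`SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E4M3_SS_TN` / `..._RS_TN`), the empty `template <>` headers, and the `GMMA::Major::K` argument to `GMMA::smem_desc` are assumptions made for illustration, not text recovered from the diff.

```cpp
// Illustrative reconstruction only: the specialization arguments and template
// headers are assumptions; every member alias is copied from the hunk above.
namespace cute {

// SS flavor: A and B both come from shared memory, so the traits expose a
// shared-memory descriptor fragment type for each operand.
template <>
struct MMA_Traits<SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E4M3_SS_TN>  // assumed op name
{
  using ValTypeD = half_t;
  using ValTypeA = sparse_elem<2, float_e4m3_t>;  // 2:4 structured-sparse A values
  using ValTypeE = sparse_elem<8, uint8_t>;       // sparsity metadata (E operand)
  using ValTypeB = float_e4m3_t;
  using ValTypeC = half_t;

  using FrgTypeA = GMMA::smem_desc<GMMA::Major::K>;  // Major::K assumed (FP8 GMMA is TN-only)
  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;

  using Shape_MNK = Shape<_64,_24,_64>;
  using ThrID     = Layout<_128>;
  using ALayout   = GMMA::ABLayout< 64, 64>;
  using ELayout   = GMMA::ELayout_64x64;
  using BLayout   = GMMA::ABLayout< 24, 64>;
  using CLayout   = GMMA::CLayout_64x24;

  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};

// RS flavor: A is fed from the register file, so FrgTypeA is omitted and
// ALayout uses ALayout_64x64 instead of the shared-memory ABLayout.
template <>
struct MMA_Traits<SM90::GMMA::SPARSE::GMMA_64x24x64_F16E4M3E4M3_RS_TN>  // assumed op name
{
  using ValTypeD = half_t;
  using ValTypeA = sparse_elem<2, float_e4m3_t>;
  using ValTypeE = sparse_elem<8, uint8_t>;
  using ValTypeB = float_e4m3_t;
  using ValTypeC = half_t;

  using FrgTypeB = GMMA::smem_desc<GMMA::Major::K>;

  using Shape_MNK = Shape<_64,_24,_64>;
  using ThrID     = Layout<_128>;
  using ALayout   = GMMA::ALayout_64x64;
  using ELayout   = GMMA::ELayout_64x64;
  using BLayout   = GMMA::ABLayout< 24, 64>;
  using CLayout   = GMMA::CLayout_64x24;

  GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One;
};

} // namespace cute
```

Every remaining entry in this section repeats one of these two shapes, varying only the N extent (24 through 248, which drives `Shape_MNK`, `BLayout`, and `CLayout`), the accumulator type (`half_t` vs `float`), and, toward the end of the section, the B operand type (`float_e4m3_t` vs `float_e5m2_t`).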
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, 
float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using 
ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + 
using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using 
Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + 
GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ 
+ using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB 
= GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = 
GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
+template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = 
GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = 
GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = 
sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; 
+ using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = 
GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + 
using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = 
Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, 
float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; 
+ using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + 
+ GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using 
ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using 
Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using 
CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + 
using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB 
= GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = 
GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
+template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using 
ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = 
GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = 
sparse_elem<2, float_e4m3_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = 
Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + 
using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK 
= Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = 
float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; 
+ + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = 
GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template 
+struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = 
GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = 
GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = 
sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = 
Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = 
sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = 
Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = 
GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e4m3_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_24,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 24, 64>; + using CLayout = GMMA::CLayout_64x24; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = 
half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_40,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 40, 64>; + using CLayout = GMMA::CLayout_64x40; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + 
using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_48,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 48, 64>; + using CLayout = GMMA::CLayout_64x48; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout 
= GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_56,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 56, 64>; + using CLayout = GMMA::CLayout_64x56; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_72,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 72, 64>; + using CLayout = GMMA::CLayout_64x72; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = 
half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_80,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 80, 64>; + using CLayout = GMMA::CLayout_64x80; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + 
using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_88,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout< 88, 64>; + using CLayout = GMMA::CLayout_64x88; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using 
CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_104,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<104, 64>; + using CLayout = GMMA::CLayout_64x104; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_112,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<112, 64>; + using CLayout = GMMA::CLayout_64x112; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + 
using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_120,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<120, 64>; + using CLayout = GMMA::CLayout_64x120; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using 
FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_136,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<136, 64>; + using CLayout = GMMA::CLayout_64x136; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout 
= GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_144,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<144, 64>; + using CLayout = GMMA::CLayout_64x144; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_152,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<152, 64>; + using CLayout = GMMA::CLayout_64x152; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + 
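// Annotation (not part of the patch): the specializations in this hunk all follow one pattern.
// Each describes a sparse Hopper GMMA tile with Shape_MNK = 64 x N x 64, where only N varies:
//   * ValTypeA = sparse_elem<2, T>  -- operand A is 2:4 structured-sparse, so each stored
//     value of type T stands for two logical K-elements.
//   * ValTypeE = sparse_elem<8, uint8_t> -- E is the sparsity metadata; one metadata byte
//     covers eight logical elements of A.
//   * A specialization that defines FrgTypeA = GMMA::smem_desc<...> is an SS atom (A is read
//     from shared memory via a matrix descriptor); its sibling without FrgTypeA exposes
//     ALayout = GMMA::ALayout_64x64 instead and sources A from registers (RS atom).
//   * accumulate_ = GMMA::ScaleOut::One selects D = A*B + C; GMMA::ScaleOut::Zero would
//     compute D = A*B only.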
+template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_160,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<160, 64>; + using CLayout = GMMA::CLayout_64x160; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using 
ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_168,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<168, 64>; + using CLayout = GMMA::CLayout_64x168; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = 
GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_176,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<176, 64>; + using CLayout = GMMA::CLayout_64x176; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_184,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<184, 64>; + using CLayout = GMMA::CLayout_64x184; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + 
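// Illustrative sketch (an assumption-laden aside, not code from this patch): any op that gains
// an MMA_Traits specialization like the ones above can be wrapped into a CuTe TiledMMA in the
// usual way. `SparseGmmaOp` and `make_warpgroup_tiled_mma` below are hypothetical names used
// only for illustration; substitute the sparse GMMA op matching the desired shape and dtypes.

#include <cute/tensor.hpp>
#include <cute/atom/mma_atom.hpp>

// Build a single-warpgroup TiledMMA from a (sparse) GMMA op that has an MMA_Traits specialization.
template <class SparseGmmaOp>
auto make_warpgroup_tiled_mma()
{
  using Traits = cute::MMA_Traits<SparseGmmaOp>;   // e.g. one of the specializations above
  // GMMA atoms are warpgroup-wide: ThrID is Layout<_128>.
  static_assert(cute::size(typename Traits::ThrID{}) == 128, "expected a warpgroup-wide atom");
  // A 1x1x1 atom layout tiles a single atom, so the TiledMMA's MNK tile equals
  // Traits::Shape_MNK (here 64 x N x 64, with N fixed by the chosen op).
  return cute::make_tiled_mma(cute::MMA_Atom<SparseGmmaOp>{},
                              cute::Layout<cute::Shape<cute::_1, cute::_1, cute::_1>>{});
}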
+//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_200,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<200, 64>; + using CLayout = GMMA::CLayout_64x200; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = 
sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_208,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<208, 64>; + using CLayout = GMMA::CLayout_64x208; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = 
Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_216,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<216, 64>; + using CLayout = GMMA::CLayout_64x216; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_224,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<224, 64>; + using CLayout = GMMA::CLayout_64x224; + + GMMA::ScaleOut 
accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_232,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<232, 64>; + using CLayout = GMMA::CLayout_64x232; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using 
ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_240,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<240, 64>; + using CLayout = GMMA::CLayout_64x240; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = half_t; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = half_t; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeA = GMMA::smem_desc; + using FrgTypeB = 
GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ABLayout< 64, 64>; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct MMA_Traits> +{ + using ValTypeD = float; + using ValTypeA = sparse_elem<2, float_e5m2_t>; + using ValTypeE = sparse_elem<8, uint8_t>; + using ValTypeB = float_e5m2_t; + using ValTypeC = float; + + using FrgTypeB = GMMA::smem_desc; + + using Shape_MNK = Shape<_64,_248,_64>; + using ThrID = Layout<_128>; + using ALayout = GMMA::ALayout_64x64; + using ELayout = GMMA::ELayout_64x64; + using BLayout = GMMA::ABLayout<248, 64>; + using CLayout = GMMA::CLayout_64x248; + + GMMA::ScaleOut accumulate_ = GMMA::ScaleOut::One; +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // end namespace cute diff --git a/include/cute/config.hpp b/include/cute/config.hpp new file mode 100644 index 0000000000..792eee90f0 --- /dev/null +++ b/include/cute/config.hpp @@ -0,0 +1,149 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#if defined(__CUDACC__) || defined(_NVHPC_CUDA) +# define CUTE_HOST_DEVICE __forceinline__ __host__ __device__ +# define CUTE_DEVICE __forceinline__ __device__ +# define CUTE_HOST __forceinline__ __host__ +#else +# define CUTE_HOST_DEVICE inline +# define CUTE_DEVICE inline +# define CUTE_HOST inline +#endif // CUTE_HOST_DEVICE, CUTE_DEVICE + +#if defined(__CUDACC_RTC__) +# define CUTE_HOST_RTC CUTE_HOST_DEVICE +#else +# define CUTE_HOST_RTC CUTE_HOST +#endif + +#if !defined(__CUDACC_RTC__) && !defined(__clang__) && \ + (defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA)) +# define CUTE_UNROLL #pragma unroll +# define CUTE_NO_UNROLL #pragma unroll 1 +#elif defined(__CUDACC_RTC__) || defined(__clang__) +# define CUTE_UNROLL _Pragma("unroll") +# define CUTE_NO_UNROLL _Pragma("unroll 1") +#else +# define CUTE_UNROLL +# define CUTE_NO_UNROLL +#endif // CUTE_UNROLL + +#if defined(__CUDA_ARCH__) || defined(_NVHPC_CUDA) +# define CUTE_INLINE_CONSTANT static const __device__ +#else +# define CUTE_INLINE_CONSTANT static constexpr +#endif + +// __grid_constant__ was introduced in CUDA 11.7. +#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 7))) +# define CUTE_GRID_CONSTANT_SUPPORTED +#endif + +// __grid_constant__ can be enabled only on SM70+. +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700)) +# define CUTE_GRID_CONSTANT_ENABLED +#endif + +#if ! defined(CUTE_GRID_CONSTANT) +# if defined(CUTE_GRID_CONSTANT_SUPPORTED) && defined(CUTE_GRID_CONSTANT_ENABLED) +# define CUTE_GRID_CONSTANT __grid_constant__ +# else +# define CUTE_GRID_CONSTANT +# endif +#endif + +// Some versions of GCC < 11 have trouble deducing that a +// function with "auto" return type and all of its returns in an "if +// constexpr ... else" statement must actually return. Thus, GCC +// emits spurious "missing return statement" build warnings. +// Developers can suppress these warnings by using the +// CUTE_GCC_UNREACHABLE macro, which must be followed by a semicolon. +// It's harmless to use the macro for other GCC versions or other +// compilers, but it has no effect. +#if ! defined(CUTE_GCC_UNREACHABLE) +# if defined(__GNUC__) +# define CUTE_GCC_UNREACHABLE __builtin_unreachable() +# else +# define CUTE_GCC_UNREACHABLE +# endif +#endif + +#if defined(_MSC_VER) +// Provides support for alternative operators 'and', 'or', and 'not' +# include +#endif // _MSC_VER + +#if defined(__CUDACC_RTC__) +# define CUTE_STL_NAMESPACE cuda::std +# define CUTE_STL_NAMESPACE_IS_CUDA_STD +#else +# define CUTE_STL_NAMESPACE std +#endif + +// +// Assertion helpers +// + +#if defined(__CUDACC_RTC__) +# include +#else +# include +#endif + +#define CUTE_STATIC_V(x) decltype(x)::value + +#define CUTE_STATIC_ASSERT static_assert +#define CUTE_STATIC_ASSERT_V(x,...) static_assert(decltype(x)::value, ##__VA_ARGS__) + +// Fail and print a message. Typically used for notification of a compiler misconfiguration. 
+#if defined(__CUDA_ARCH__) +# define CUTE_INVALID_CONTROL_PATH(x) assert(0 && x); printf(x); __brkpt() +#else +# define CUTE_INVALID_CONTROL_PATH(x) assert(0 && x); printf(x) +#endif + +// +// IO +// + +#if !defined(__CUDACC_RTC__) +# include +# include +# include +#endif + +// +// Debugging utilities +// + +#include diff --git a/include/cute/container/alignment.hpp b/include/cute/container/alignment.hpp new file mode 100644 index 0000000000..52e4cbadd9 --- /dev/null +++ b/include/cute/container/alignment.hpp @@ -0,0 +1,70 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +// Test if a pointer is aligned to N bytes +template +CUTE_HOST_DEVICE constexpr +bool +is_byte_aligned(void const* const ptr) +{ + static_assert(has_single_bit(N), "N must be a power of 2 in alignment check"); + return (reinterpret_cast(ptr) & (N-1)) == 0; +} + +#if defined(__CUDACC__) +# define CUTE_ALIGNAS(n) __align__(n) +#else +# define CUTE_ALIGNAS(n) alignas(n) +#endif + +template +struct aligned_struct {}; + +template struct CUTE_ALIGNAS( 1) aligned_struct< 1, Child> {}; +template struct CUTE_ALIGNAS( 2) aligned_struct< 2, Child> {}; +template struct CUTE_ALIGNAS( 4) aligned_struct< 4, Child> {}; +template struct CUTE_ALIGNAS( 8) aligned_struct< 8, Child> {}; +template struct CUTE_ALIGNAS( 16) aligned_struct< 16, Child> {}; +template struct CUTE_ALIGNAS( 32) aligned_struct< 32, Child> {}; +template struct CUTE_ALIGNAS( 64) aligned_struct< 64, Child> {}; +template struct CUTE_ALIGNAS(128) aligned_struct<128, Child> {}; +template struct CUTE_ALIGNAS(256) aligned_struct<256, Child> {}; + +} // end namespace cute diff --git a/include/cute/container/array.hpp b/include/cute/container/array.hpp new file mode 100644 index 0000000000..9cdcf5f4c2 --- /dev/null +++ b/include/cute/container/array.hpp @@ -0,0 +1,492 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include + +namespace cute +{ + +template +struct array +{ + using element_type = T; + using value_type = remove_cv_t; + using size_type = size_t; + using difference_type = ptrdiff_t; + using reference = element_type&; + using const_reference = const element_type&; + using pointer = element_type*; + using const_pointer = const element_type*; + using iterator = pointer; + using const_iterator = const_pointer; + + CUTE_HOST_DEVICE constexpr + reference operator[](size_type pos) + { + return begin()[pos]; + } + + CUTE_HOST_DEVICE constexpr + const_reference operator[](size_type pos) const + { + return begin()[pos]; + } + + CUTE_HOST_DEVICE constexpr + reference front() + { + return *begin(); + } + + CUTE_HOST_DEVICE constexpr + const_reference front() const + { + return *begin(); + } + + CUTE_HOST_DEVICE constexpr + reference back() + { + // return *rbegin(); + return operator[](N-1); + } + + CUTE_HOST_DEVICE constexpr + const_reference back() const + { + // return *rbegin(); + return operator[](N-1); + } + + CUTE_HOST_DEVICE constexpr + T* data() + { + return __elems_; + } + + CUTE_HOST_DEVICE constexpr + T const* data() const + { + return __elems_; + } + + CUTE_HOST_DEVICE constexpr + iterator begin() + { + return data(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator begin() const + { + return data(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator cbegin() + { + return begin(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator cbegin() const + { + return begin(); + } + + CUTE_HOST_DEVICE constexpr + iterator end() + { + return data() + size(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator end() const + { + return data() + size(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator cend() + { + return end(); + } + + CUTE_HOST_DEVICE constexpr + const_iterator cend() const + { + return end(); + } + + CUTE_HOST_DEVICE constexpr + bool empty() const + { + return size() == 0; + } + + CUTE_HOST_DEVICE constexpr + size_type size() const + { + return N; + } + + CUTE_HOST_DEVICE constexpr + size_type max_size() const + { + return size(); + } + + CUTE_HOST_DEVICE constexpr + void fill(const T& value) + { + for (auto& e : *this) { + e = value; + } + } + + CUTE_HOST_DEVICE constexpr + void clear() + { + fill(T(0)); + } + + CUTE_HOST_DEVICE constexpr + void swap(array& other) + { + using CUTE_STL_NAMESPACE::swap; + for (size_type i = 0; i < size(); ++i) { + swap((*this)[i], other[i]); + } + } + + element_type __elems_[N]; +}; + + +template +struct array +{ + using element_type = T; + using value_type = remove_cv_t; + using size_type = size_t; + using difference_type = ptrdiff_t; + using reference = element_type&; + using const_reference = const element_type&; + using pointer = element_type*; + using const_pointer = const element_type*; + using const_iterator = const_pointer; + using iterator = pointer; + + CUTE_HOST_DEVICE constexpr + reference operator[](size_type pos) + { + return begin()[pos]; + } + + CUTE_HOST_DEVICE constexpr + const_reference operator[](size_type pos) const + { + return begin()[pos]; + } + + CUTE_HOST_DEVICE constexpr + reference front() + { + return *begin(); + } + + CUTE_HOST_DEVICE constexpr + const_reference front() const + { + return *begin(); + } + + CUTE_HOST_DEVICE constexpr + reference back() + { + return *begin(); + } + + CUTE_HOST_DEVICE constexpr + const_reference back() const + { + return *begin(); + } + + 
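+  // Zero-length specialization: the accessors below deliberately return
+  // nullptr / 0, and fill/clear/swap are no-ops, so generic code can treat
+  // an empty array uniformly without special-casing N == 0.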
CUTE_HOST_DEVICE constexpr + T* data() + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + T const* data() const + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + iterator begin() + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + const_iterator begin() const + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + const_iterator cbegin() + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + const_iterator cbegin() const + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + iterator end() + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + const_iterator end() const + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + const_iterator cend() + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + const_iterator cend() const + { + return nullptr; + } + + CUTE_HOST_DEVICE constexpr + bool empty() const + { + return true; + } + + CUTE_HOST_DEVICE constexpr + size_type size() const + { + return 0; + } + + CUTE_HOST_DEVICE constexpr + size_type max_size() const + { + return 0; + } + + CUTE_HOST_DEVICE constexpr + void fill(const T& value) + {} + + CUTE_HOST_DEVICE constexpr + void clear() + {} + + CUTE_HOST_DEVICE constexpr + void swap(array& other) + {} +}; + +template +CUTE_HOST_DEVICE constexpr +bool operator==(array const& lhs, array const& rhs) +{ + for (size_t i = 0; i < N; ++i) { + if (lhs[i] != rhs[i]) { + return false; + } + } + return true; +} + +template +CUTE_HOST_DEVICE constexpr +void clear(array& a) +{ + a.fill(T(0)); +} + +template +CUTE_HOST_DEVICE constexpr +void fill(array& a, T const& value) +{ + a.fill(value); +} + +template +CUTE_HOST_DEVICE constexpr +void swap(array& a, array& b) +{ + a.swap(b); +} + +/// @return A cute::array of the elements of @c t in reverse order. +template +CUTE_HOST_DEVICE constexpr +cute::array reverse(cute::array const& t) +{ + if constexpr (N == 0u) { + return t; + } else { + cute::array t_r{}; + for (size_t k = 0; k < N; ++k) { + t_r[k] = t[N - k - 1]; + } + return t_r; + } +} + +} // end cute + + +// +// Specialize tuple-related functionality for cute::array +// + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +namespace cute +{ + +template +CUTE_HOST_DEVICE constexpr +T& get(array& a) +{ + static_assert(I < N, "Index out of range"); + return a[I]; +} + +template +CUTE_HOST_DEVICE constexpr +T const& get(array const& a) +{ + static_assert(I < N, "Index out of range"); + return a[I]; +} + +template +CUTE_HOST_DEVICE constexpr +T&& get(array&& a) +{ + static_assert(I < N, "Index out of range"); + return cute::move(a[I]); +} + +} // end namespace cute + +namespace CUTE_STL_NAMESPACE +{ + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> +{ + using type = T; +}; + +template +struct tuple_size const> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element const> +{ + using type = T; +}; + +} // end namespace CUTE_STL_NAMESPACE + +#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD +namespace std +{ + +#if defined(__CUDACC_RTC__) +template +struct tuple_size; + +template +struct tuple_element; +#endif + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> +{ + using type = T; +}; + +template +struct tuple_size const> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element const> +{ + using type = T; +}; + +} // end namespace std +#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/include/cute/container/array_aligned.hpp 
b/include/cute/container/array_aligned.hpp new file mode 100644 index 0000000000..a9d14a1a25 --- /dev/null +++ b/include/cute/container/array_aligned.hpp @@ -0,0 +1,42 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_ALIGNAS +#include // cute::array + +namespace cute +{ + +template +struct CUTE_ALIGNAS(Alignment) array_aligned : cute::array {}; + +} // end namespace cute diff --git a/include/cute/container/array_subbyte.hpp b/include/cute/container/array_subbyte.hpp new file mode 100644 index 0000000000..48d416f45b --- /dev/null +++ b/include/cute/container/array_subbyte.hpp @@ -0,0 +1,662 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Statically sized array of elements that accommodates subbyte trivial types + in a packed storage. +*/ + +#pragma once + +#include + +#include +#include + +namespace cute +{ +// +// Underlying subbyte storage type +// +template +using subbyte_storage_type_t = conditional_t<(cute::sizeof_bits_v <= 8), uint8_t, + conditional_t<(cute::sizeof_bits_v <= 16), uint16_t, + conditional_t<(cute::sizeof_bits_v <= 32), uint32_t, + conditional_t<(cute::sizeof_bits_v <= 64), uint64_t, + conditional_t<(cute::sizeof_bits_v <= 128), uint128_t, + T>>>>>; + +template struct subbyte_iterator; +template struct swizzle_ptr; + +// +// subbyte_reference +// Proxy object for sub-byte element references +// +template +struct subbyte_reference +{ + // Iterator Element type (const or non-const) + using element_type = T; + // Iterator Value type without type qualifier. + using value_type = remove_cv_t; + // Storage type (const or non-const) + using storage_type = conditional_t<(is_const_v), subbyte_storage_type_t const, subbyte_storage_type_t>; + + static_assert(sizeof_bits_v % 8 == 0, "Storage type is not supported"); + + static_assert(sizeof_bits_v <= sizeof_bits_v, + "Size of Element must not be greater than Storage."); + +private: + + // Bitmask for covering one item + static constexpr storage_type BitMask = storage_type(storage_type(-1) >> (sizeof_bits_v - sizeof_bits_v)); + // Flag for fast branching on straddled elements + static constexpr bool is_storage_unaligned = ((sizeof_bits_v % sizeof_bits_v) != 0); + + friend struct subbyte_iterator; + + // Pointer to storage element + storage_type* ptr_ = nullptr; + + // Bit index of value_type starting position within storage_type element. 
+ // RI: 0 <= idx_ < sizeof_bit + uint8_t idx_ = 0; + + // Ctor + template + CUTE_HOST_DEVICE constexpr + subbyte_reference(PointerType* ptr, uint8_t idx = 0) : ptr_(reinterpret_cast(ptr)), idx_(idx) {} + +public: + + // Copy Ctor + CUTE_HOST_DEVICE constexpr + subbyte_reference(subbyte_reference const& other) { + *this = other.get(); + } + + CUTE_HOST_DEVICE constexpr + subbyte_reference(subbyte_reference const& other) { + *this = other.get(); + } + + // Copy Assignment + CUTE_HOST_DEVICE constexpr + subbyte_reference& operator=(subbyte_reference const& other) { + return *this = other.get(); + } + + CUTE_HOST_DEVICE constexpr + subbyte_reference& operator=(subbyte_reference const& other) { + return *this = other.get(); + } + + // Assignment + template + CUTE_HOST_DEVICE constexpr + enable_if_t, subbyte_reference&> operator=(value_type x) + { + static_assert(is_same_v, "Do not specify template arguments!"); + storage_type item = (reinterpret_cast(x) & BitMask); + + // Update the current storage element + storage_type bit_mask_0 = storage_type(BitMask << idx_); + ptr_[0] = storage_type((ptr_[0] & ~bit_mask_0) | (item << idx_)); + + // If value_type is unaligned with storage_type (static) and this is a straddled value (dynamic) + if (is_storage_unaligned && idx_ + sizeof_bits_v > sizeof_bits_v) { + uint8_t straddle_bits = uint8_t(sizeof_bits_v - idx_); + storage_type bit_mask_1 = storage_type(BitMask >> straddle_bits); + // Update the next storage element + ptr_[1] = storage_type((ptr_[1] & ~bit_mask_1) | (item >> straddle_bits)); + } + + return *this; + } + + // Comparison of referenced values + CUTE_HOST_DEVICE constexpr friend + bool operator==(subbyte_reference const& x, subbyte_reference const& y) { return x.get() == y.get(); } + CUTE_HOST_DEVICE constexpr friend + bool operator!=(subbyte_reference const& x, subbyte_reference const& y) { return x.get() != y.get(); } + CUTE_HOST_DEVICE constexpr friend + bool operator< (subbyte_reference const& x, subbyte_reference const& y) { return x.get() < y.get(); } + CUTE_HOST_DEVICE constexpr friend + bool operator> (subbyte_reference const& x, subbyte_reference const& y) { return x.get() > y.get(); } + CUTE_HOST_DEVICE constexpr friend + bool operator<=(subbyte_reference const& x, subbyte_reference const& y) { return x.get() <= y.get(); } + CUTE_HOST_DEVICE constexpr friend + bool operator>=(subbyte_reference const& x, subbyte_reference const& y) { return x.get() >= y.get(); } + + // Value + CUTE_HOST_DEVICE + value_type get() const + { + if constexpr (is_same_v) { // Extract to bool -- potentially faster impl + return bool((*ptr_) & (BitMask << idx_)); + } else { // Extract to value_type + // Extract from the current storage element + auto item = storage_type((ptr_[0] >> idx_) & BitMask); + + // If value_type is unaligned with storage_type (static) and this is a straddled value (dynamic) + if (is_storage_unaligned && idx_ + sizeof_bits_v > sizeof_bits_v) { + uint8_t straddle_bits = uint8_t(sizeof_bits_v - idx_); + storage_type bit_mask_1 = storage_type(BitMask >> straddle_bits); + // Extract from the next storage element + item |= storage_type((ptr_[1] & bit_mask_1) << straddle_bits); + } + + return reinterpret_cast(item); + } + } + + // Extract to type value_type + CUTE_HOST_DEVICE constexpr + operator value_type() const { + return get(); + } + + // Address + CUTE_HOST_DEVICE + subbyte_iterator operator&() const { + return {ptr_, idx_}; + } +}; + +template +CUTE_HOST_DEVICE +void +print(subbyte_reference ref) { + cute::print(ref.get()); +} + 
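For readers skimming the diff, the following is a minimal standalone sketch (not CUTLASS code) of the pack/extract arithmetic that `subbyte_reference::get()` and its assignment operator perform when a sub-byte value straddles two storage words. The 6-bit element width, the `uint8_t` storage type, and the helper names `extract`/`insert` are illustrative choices, not part of the library; the same masks appear in the real implementation as `BitMask`, `bit_mask_0`, and `bit_mask_1`.

```cpp
// Standalone sketch of the straddled read/write path used by subbyte_reference.
// 6-bit values packed into uint8_t storage, so some values cross a byte boundary.
#include <cassert>
#include <cstdint>

constexpr unsigned value_bits   = 6;
constexpr unsigned storage_bits = 8;
constexpr uint8_t  bit_mask     = uint8_t((1u << value_bits) - 1);   // covers one item

uint8_t extract(uint8_t const* ptr, unsigned idx /* bit offset within ptr[0] */) {
  uint8_t item = uint8_t((ptr[0] >> idx) & bit_mask);                // low part
  if (idx + value_bits > storage_bits) {                             // straddled value?
    unsigned straddle_bits = storage_bits - idx;                     // bits taken from ptr[0]
    uint8_t  mask_hi       = uint8_t(bit_mask >> straddle_bits);     // bits living in ptr[1]
    item |= uint8_t((ptr[1] & mask_hi) << straddle_bits);            // high part
  }
  return item;
}

void insert(uint8_t* ptr, unsigned idx, uint8_t value) {
  uint8_t item   = uint8_t(value & bit_mask);
  uint8_t mask_0 = uint8_t(bit_mask << idx);
  ptr[0] = uint8_t((ptr[0] & ~mask_0) | (item << idx));              // current storage word
  if (idx + value_bits > storage_bits) {
    unsigned straddle_bits = storage_bits - idx;
    uint8_t  mask_1        = uint8_t(bit_mask >> straddle_bits);
    ptr[1] = uint8_t((ptr[1] & ~mask_1) | (item >> straddle_bits));  // next storage word
  }
}

int main() {
  uint8_t buf[2] = {0, 0};
  insert(buf, 5, 0x2A);              // 6-bit value at bit offset 5 straddles buf[0]/buf[1]
  assert(extract(buf, 5) == 0x2A);   // round-trips through both storage words
  return 0;
}
```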
+template +CUTE_HOST_DEVICE +void +pretty_print(subbyte_reference ref) { + cute::pretty_print(ref.get()); +} + +// +// subbyte_iterator +// Random-access iterator over subbyte references +// +template +struct subbyte_iterator +{ + // Iterator Element type (const or non-const) + using element_type = T; + // Iterator Value type without type qualifier. + using value_type = remove_cv_t; + // Storage type (const or non-const) + using storage_type = conditional_t<(is_const_v), subbyte_storage_type_t const, subbyte_storage_type_t>; + // Reference proxy type + using reference = subbyte_reference; + + static_assert(sizeof_bits_v % 8 == 0, "Storage type is not supported"); + + static_assert(sizeof_bits_v <= sizeof_bits_v, + "Size of Element must not be greater than Storage."); + +private: + + template friend struct swizzle_ptr; + template friend CUTE_HOST_DEVICE constexpr U* raw_pointer_cast(subbyte_iterator const&); + template friend CUTE_HOST_DEVICE constexpr auto recast_ptr(subbyte_iterator const&); + template friend CUTE_HOST_DEVICE void print(subbyte_iterator const&); + + // Pointer to storage element + storage_type* ptr_; + + // Bit index of value_type starting position within storage_type element. + // RI: 0 <= idx_ < sizeof_bit + uint8_t idx_; + +public: + + // Default Ctor + CUTE_HOST_DEVICE constexpr + subbyte_iterator() : ptr_{nullptr}, idx_{0} {}; + + // Ctor + template + CUTE_HOST_DEVICE constexpr + subbyte_iterator(PointerType* ptr, uint8_t idx = 0) : ptr_(reinterpret_cast(ptr)), idx_(idx) { } + + CUTE_HOST_DEVICE constexpr + reference operator*() const { + return reference(ptr_, idx_); + } + + CUTE_HOST_DEVICE constexpr + subbyte_iterator& operator+=(uint64_t k) { + k = sizeof_bits_v * k + idx_; + ptr_ += k / sizeof_bits_v; + idx_ = k % sizeof_bits_v; + return *this; + } + + CUTE_HOST_DEVICE constexpr + subbyte_iterator operator+(uint64_t k) const { + return subbyte_iterator(ptr_, idx_) += k; + } + + CUTE_HOST_DEVICE constexpr + reference operator[](uint64_t k) const { + return *(*this + k); + } + + CUTE_HOST_DEVICE constexpr + subbyte_iterator& operator++() { + idx_ += sizeof_bits_v; + if (idx_ >= sizeof_bits_v) { + ++ptr_; + idx_ -= sizeof_bits_v; + } + return *this; + } + + CUTE_HOST_DEVICE constexpr + subbyte_iterator operator++(int) { + subbyte_iterator ret(*this); + ++(*this); + return ret; + } + + CUTE_HOST_DEVICE constexpr + subbyte_iterator& operator--() { + if (idx_ >= sizeof_bits_v) { + idx_ -= sizeof_bits_v; + } else { + --ptr_; + idx_ += sizeof_bits_v - sizeof_bits_v; + } + return *this; + } + + CUTE_HOST_DEVICE constexpr + subbyte_iterator operator--(int) { + subbyte_iterator ret(*this); + --(*this); + return ret; + } + + CUTE_HOST_DEVICE constexpr friend + bool operator==(subbyte_iterator const& x, subbyte_iterator const& y) { + return x.ptr_ == y.ptr_ && x.idx_ == y.idx_; + } + CUTE_HOST_DEVICE constexpr friend + bool operator!=(subbyte_iterator const& x, subbyte_iterator const& y) { return !(x == y); } + CUTE_HOST_DEVICE constexpr friend + bool operator< (subbyte_iterator const& x, subbyte_iterator const& y) { + return x.ptr_ < y.ptr_ || (x.ptr_ == y.ptr_ && x.idx_ < y.idx_); + } + CUTE_HOST_DEVICE constexpr friend + bool operator<=(subbyte_iterator const& x, subbyte_iterator const& y) { return !(y < x); } + CUTE_HOST_DEVICE constexpr friend + bool operator> (subbyte_iterator const& x, subbyte_iterator const& y) { return (y < x); } + CUTE_HOST_DEVICE constexpr friend + bool operator>=(subbyte_iterator const& x, subbyte_iterator const& y) { return !(x < y); } +}; + +// 
Conversion to raw pointer with loss of subbyte index +template +CUTE_HOST_DEVICE constexpr +T* +raw_pointer_cast(subbyte_iterator const& x) { + assert(x.idx_ == 0); + return reinterpret_cast(x.ptr_); +} + +// Conversion to NewT_ with possible loss of subbyte index +template +CUTE_HOST_DEVICE constexpr +auto +recast_ptr(subbyte_iterator const& x) { + using NewT = conditional_t<(is_const_v), NewT_ const, NewT_>; + if constexpr (cute::is_subbyte_v) { // Making subbyte_iter, preserve the subbyte idx + return subbyte_iterator(x.ptr_, x.idx_); + } else { // Not subbyte, assume/assert subbyte idx 0 + return reinterpret_cast(raw_pointer_cast(x)); + } + CUTE_GCC_UNREACHABLE; +} + +// Dynamic pointers have unknown static alignment +template +CUTE_HOST_DEVICE constexpr +Int<0> +max_alignment(subbyte_iterator const& x) { + return {}; +} + +template +CUTE_HOST_DEVICE void +print(subbyte_iterator const& x) { + printf("subptr[%db](%p.%u)", int(sizeof_bits_v), x.ptr_, x.idx_); +} + +template +CUTE_HOST_DEVICE void +print(subbyte_reference const& x) { + print(x.get()); +} + +// +// array_subbyte +// Statically sized array for non-byte-aligned data types +// +template +struct array_subbyte +{ + using element_type = T; + using value_type = remove_cv_t; + using pointer = element_type*; + using const_pointer = element_type const*; + + using size_type = size_t; + using difference_type = ptrdiff_t; + + // + // References + // + using reference = subbyte_reference; + using const_reference = subbyte_reference; + + // + // Iterators + // + using iterator = subbyte_iterator; + using const_iterator = subbyte_iterator; + + // Storage type (const or non-const) + using storage_type = conditional_t<(is_const_v), subbyte_storage_type_t const, subbyte_storage_type_t>; + + static_assert(sizeof_bits_v % 8 == 0, "Storage type is not supported"); + +private: + + // Number of storage elements, ceil_div + static constexpr size_type StorageElements = (N * sizeof_bits_v + sizeof_bits_v - 1) / sizeof_bits_v; + + // Internal storage + storage_type storage[StorageElements]; + +public: + + CUTE_HOST_DEVICE constexpr + size_type size() const { + return N; + } + + CUTE_HOST_DEVICE constexpr + size_type max_size() const { + return N; + } + + CUTE_HOST_DEVICE constexpr + bool empty() const { + return !N; + } + + // Efficient clear method + CUTE_HOST_DEVICE constexpr + void clear() { + CUTE_UNROLL + for (size_type i = 0; i < StorageElements; ++i) { + storage[i] = storage_type(0); + } + } + + CUTE_HOST_DEVICE constexpr + void fill(T const& value) { + CUTE_UNROLL + for (size_type i = 0; i < N; ++i) { + at(i) = value; + } + } + + CUTE_HOST_DEVICE constexpr + reference at(size_type pos) { + return iterator(storage)[pos]; + } + + CUTE_HOST_DEVICE constexpr + const_reference at(size_type pos) const { + return const_iterator(storage)[pos]; + } + + CUTE_HOST_DEVICE constexpr + reference operator[](size_type pos) { + return at(pos); + } + + CUTE_HOST_DEVICE constexpr + const_reference operator[](size_type pos) const { + return at(pos); + } + + CUTE_HOST_DEVICE constexpr + reference front() { + return at(0); + } + + CUTE_HOST_DEVICE constexpr + const_reference front() const { + return at(0); + } + + CUTE_HOST_DEVICE constexpr + reference back() { + return at(N-1); + } + + CUTE_HOST_DEVICE constexpr + const_reference back() const { + return at(N-1); + } + + // In analogy to std::vector::data(), these functions are deleted to prevent bugs. 
+ // Instead, prefer + // auto* data = raw_pointer_cast(my_subbyte_array.begin()); + // where the type of auto* is implementation-defined and + // with the knowledge that [data, data + my_subbyte_array.size()) may not be a valid range. + CUTE_HOST_DEVICE constexpr + pointer data() = delete; + + CUTE_HOST_DEVICE constexpr + const_pointer data() const = delete; + + CUTE_HOST_DEVICE constexpr + iterator begin() { + return iterator(storage); + } + + CUTE_HOST_DEVICE constexpr + const_iterator begin() const { + return const_iterator(storage); + } + + CUTE_HOST_DEVICE constexpr + const_iterator cbegin() const { + return begin(); + } + + CUTE_HOST_DEVICE constexpr + iterator end() { + return iterator(storage) + N; + } + + CUTE_HOST_DEVICE constexpr + const_iterator end() const { + return const_iterator(storage) + N; + } + + CUTE_HOST_DEVICE constexpr + const_iterator cend() const { + return end(); + } + + // + // Comparison operators + // + +}; + +// +// Operators +// + +template +CUTE_HOST_DEVICE constexpr +void clear(array_subbyte& a) +{ + a.clear(); +} + +template +CUTE_HOST_DEVICE constexpr +void fill(array_subbyte& a, T const& value) +{ + a.fill(value); +} + +} // namespace cute + +// +// Specialize tuple-related functionality for cute::array_subbyte +// + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +namespace cute +{ + +template +CUTE_HOST_DEVICE constexpr +T& get(array_subbyte& a) +{ + static_assert(I < N, "Index out of range"); + return a[I]; +} + +template +CUTE_HOST_DEVICE constexpr +T const& get(array_subbyte const& a) +{ + static_assert(I < N, "Index out of range"); + return a[I]; +} + +template +CUTE_HOST_DEVICE constexpr +T&& get(array_subbyte&& a) +{ + static_assert(I < N, "Index out of range"); + return cute::move(a[I]); +} + +} // end namespace cute + +namespace CUTE_STL_NAMESPACE +{ + +template +struct is_reference> + : CUTE_STL_NAMESPACE::true_type +{}; + + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> +{ + using type = T; +}; + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> +{ + using type = T; +}; + +} // end namespace CUTE_STL_NAMESPACE + +#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD +namespace std +{ + +#if defined(__CUDACC_RTC__) +template +struct tuple_size; + +template +struct tuple_element; +#endif + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> +{ + using type = T; +}; + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> +{ + using type = T; +}; + +} // end namespace std +#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/include/cute/container/bit_field.hpp b/include/cute/container/bit_field.hpp new file mode 100644 index 0000000000..d7fac42a54 --- /dev/null +++ b/include/cute/container/bit_field.hpp @@ -0,0 +1,133 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Portable bit field that supports byte and word straddling that can + be used in unions to bit-wise define parameters. +*/ + +#pragma once + +#include // CUTE_HOST_DEVICE +#include // uint_bit_t +#include // cute::is_same + +namespace cute +{ + +class dummy_type {}; + +template +struct bit_field +{ + static_assert(0 < NumBits && NumBits <= 64, "bit_fields with more than 64 bits are not supported."); + + // value_type: Use the smallest value type that fits NumBits + static constexpr uint32_t value_type_bits = (NumBits <= 8) ? 8 : + (NumBits <= 16) ? 16 : + (NumBits <= 32) ? 32 : 64; + using value_type = cute::uint_bit_t; + // storage_type: Use the smallest storage_type that avoids boundary crossing + static constexpr uint32_t storage_type_bits = (BitStart / 8 == (BitStart + NumBits - 1) / 8) ? 8 : + (BitStart / 16 == (BitStart + NumBits - 1) / 16) ? 16 : + (BitStart / 32 == (BitStart + NumBits - 1) / 32) ? 32 : 64; + using storage_type = cute::uint_bit_t; + + static_assert(sizeof(OtherValueType) == sizeof(value_type) || is_same::value, + "sizeof(OtherValueType) must be same as sizeof(value_type)."); + + // Number of storage values needed: ceil_div(BitStart + NumBits, storage_type_bits) + static constexpr uint32_t N = (BitStart + NumBits + storage_type_bits - 1) / storage_type_bits; + // Index of storage value for BitStart + static constexpr uint32_t idx = BitStart / storage_type_bits; + // Bit of data_[idx] for BitStart + static constexpr uint32_t bit_lo = BitStart % storage_type_bits; + // Number of bits in data_[idx] used for NumBits if straddling, else 0 + static constexpr uint32_t bit_hi = (idx + 1 < N) ? (storage_type_bits - bit_lo) : 0; + +public: + + // NumBits mask + static constexpr value_type mask = value_type(uint64_t(-1) >> (64u - NumBits)); + // NumBits mask for BitStart + static constexpr storage_type mask_lo = storage_type(mask) << bit_lo; + // NumBits mask for leftover bits in data_[idx+1] if straddling, else 0 + static constexpr storage_type mask_hi = (idx + 1 < N) ? 
(storage_type(mask) >> bit_hi) : 0; + + storage_type data_[N]; + + // Get value + CUTE_HOST_DEVICE constexpr + value_type get() const { + storage_type result = (data_[idx] & mask_lo) >> bit_lo; + if constexpr (bit_hi != 0) { + result |= (data_[idx+1] & mask_hi) << bit_hi; + } + return static_cast(result); + } + + // Set value + CUTE_HOST_DEVICE constexpr + void set(value_type x) { + storage_type item = static_cast(x & mask); + data_[idx] = static_cast((data_[idx] & ~mask_lo) | (item << bit_lo)); + if constexpr (bit_hi != 0) { + data_[idx+1] = static_cast((data_[idx+1] & ~mask_hi) | (item >> bit_hi)); + } + } + + // Assign value + CUTE_HOST_DEVICE constexpr + bit_field& operator=(value_type x) { + set(x); + return *this; + } + + // Cast to value + CUTE_HOST_DEVICE constexpr + operator value_type () const { + return get(); + } + + // Assign OtherValueType + CUTE_HOST_DEVICE constexpr + bit_field& operator=(OtherValueType x) { + return *this = *reinterpret_cast(&x); + } + + // Cast to OtherValueType + CUTE_HOST_DEVICE constexpr + operator OtherValueType () const { + value_type x = get(); + return *reinterpret_cast(&x); + } +}; + +} // end namespace cute diff --git a/include/cute/container/cuda_types.hpp b/include/cute/container/cuda_types.hpp new file mode 100644 index 0000000000..fbc314e543 --- /dev/null +++ b/include/cute/container/cuda_types.hpp @@ -0,0 +1,183 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE, CUTE_GCC_UNREACHABLE +#include // cute::integral_constant + +namespace cute +{ + +// +// dim3 +// + +using dim3 = ::dim3; + +// MSVC doesn't define its C++ version macro to match +// its C++ language version. 
This means that when +// building with MSVC, dim3 isn't constexpr-friendly. +template +CUTE_HOST_DEVICE +#if ! defined(_MSC_VER) +constexpr +#endif +uint32_t& get(dim3& a) +{ + static_assert(I < 3, "Index out of range"); + if constexpr (I == 0) { + return a.x; + } else if constexpr (I == 1) { + return a.y; + } else if constexpr (I == 2) { + return a.z; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE +#if ! defined(_MSC_VER) +constexpr +#endif +uint32_t const& get(dim3 const& a) +{ + static_assert(I < 3, "Index out of range"); + if constexpr (I == 0) { + return a.x; + } else if constexpr (I == 1) { + return a.y; + } else if constexpr (I == 2) { + return a.z; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE +#if ! defined(_MSC_VER) +constexpr +#endif +uint32_t&& get(dim3&& a) +{ + static_assert(I < 3, "Index out of range"); + if constexpr (I == 0) { + return cute::move(a.x); + } else if constexpr (I == 1) { + return cute::move(a.y); + } else if constexpr (I == 2) { + return cute::move(a.z); + } + + CUTE_GCC_UNREACHABLE; +} + +// Specialize cute::tuple-traits for external types +template <> +struct tuple_size + : integral_constant +{}; + +template +struct tuple_element +{ + using type = uint32_t; +}; + +// +// uint3 +// + +using uint3 = ::uint3; + +template +CUTE_HOST_DEVICE constexpr +uint32_t& get(uint3& a) +{ + static_assert(I < 3, "Index out of range"); + if constexpr (I == 0) { + return a.x; + } else if constexpr (I == 1) { + return a.y; + } else if constexpr (I == 2) { + return a.z; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +uint32_t const& get(uint3 const& a) +{ + static_assert(I < 3, "Index out of range"); + if constexpr (I == 0) { + return a.x; + } else if constexpr (I == 1) { + return a.y; + } else if constexpr (I == 2) { + return a.z; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +uint32_t&& get(uint3&& a) +{ + static_assert(I < 3, "Index out of range"); + if constexpr (I == 0) { + return cute::move(a.x); + } else if constexpr (I == 1) { + return cute::move(a.y); + } else if constexpr (I == 2) { + return cute::move(a.z); + } + + CUTE_GCC_UNREACHABLE; +} + +// Specialize cute::tuple-traits for external types +template <> +struct tuple_size + : integral_constant +{}; + +template +struct tuple_element +{ + using type = uint32_t; +}; + +} // end namespace cute diff --git a/include/cute/container/packed_tuple.hpp b/include/cute/container/packed_tuple.hpp new file mode 100644 index 0000000000..c20df2c235 --- /dev/null +++ b/include/cute/container/packed_tuple.hpp @@ -0,0 +1,254 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include +#include + +namespace cute { + +namespace detail { + +// Empty Structure Optimization +template +struct ESO; + +template +static constexpr bool is_first_empty_v = cute::is_empty::value; +template +static constexpr bool is_rest_empty_v = (cute::is_empty::value && ...); + +template +using ESO_t = ESO, is_rest_empty_v, T...>; + +// Empty First and Empty Rest... +template +struct ESO { + CUTE_HOST_DEVICE constexpr + ESO() {} + + CUTE_HOST_DEVICE constexpr + ESO(First const&, Rest const&...) {} +}; + +// NonEmpty First and Empty Rest... +template +struct ESO { + CUTE_HOST_DEVICE constexpr + ESO() : first_{} {} + + CUTE_HOST_DEVICE constexpr + ESO(First const& first, Rest const&...) : first_{first} {} + + First first_; +}; + +// Empty First and NonEmpty Rest... +template +struct ESO { + CUTE_HOST_DEVICE constexpr + ESO() : rest_{} {} + + CUTE_HOST_DEVICE constexpr + ESO(First const&, Rest const&... rest) : rest_{rest...} {} + + ESO_t rest_; +}; + +// NonEmpty T and NonEmpty Rest... +template +struct ESO { + CUTE_HOST_DEVICE constexpr + ESO() : first_{}, rest_{} {} + + CUTE_HOST_DEVICE constexpr + ESO(First const& first, Rest const&... rest) : first_{first}, rest_{rest...} {} + + First first_; + ESO_t rest_; +}; + +// Get Nth value from ESO +template +CUTE_HOST_DEVICE constexpr decltype(auto) getv(ESO const& s) { + if constexpr (N == 0) { + if constexpr (F) { return T{}; } + else { return static_cast(s.first_); } + } else { + if constexpr (R) { return cute::tuple_element_t>{}; } + else { return getv(s.rest_); } + } +} + +template +CUTE_HOST_DEVICE constexpr decltype(auto) getv(ESO& s) { + if constexpr (N == 0) { + if constexpr (F) { return T{}; } + else { return static_cast(s.first_); } + } else { + if constexpr (R) { return cute::tuple_element_t>{}; } + else { return getv(s.rest_); } + } +} + +template +CUTE_HOST_DEVICE constexpr decltype(auto) getv(ESO&& s) { + if constexpr (N == 0) { + if constexpr (F) { return T{}; } + else { return static_cast(s.first_); } + } else { + if constexpr (R) { return cute::tuple_element_t>{}; } + else { return getv(static_cast&&>(s.rest_)); } + } +} + +// findt: Implementation detail of cute::find. +// If X is the first template argument of the tuple, findt returns C. 
+ +template +CUTE_HOST_DEVICE constexpr +auto +findt(ESO const& t) noexcept +{ + if constexpr (cute::is_same_v) { + return C{}; + } + else { + static_assert(sizeof...(Rest) != 0, + "The type does not appear in the argument list of the tuple."); + if constexpr (IsRestEmpty) { + // The rest is empty, so creating an instance of it is cheap. + return cute::detail::findt(ESO_t{}); + } + else { + return cute::detail::findt(t.rest_); + } + } +} + +} // end namespace detail + +// packed_tuple is a tuple type that is a standard-layout type +// whenever all of its template arguments are standard layout types: +// (cute::is_standard_layout_v && ...) implies (cute::is_standard_layout_v>) + +template +struct packed_tuple : detail::ESO_t +{ + CUTE_HOST_DEVICE constexpr + packed_tuple() {} + + CUTE_HOST_DEVICE constexpr + packed_tuple(T const&... ts) + : detail::ESO_t(ts...) + {} +}; + +template <> +struct packed_tuple<> {}; + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(packed_tuple const& t) { + static_assert(I < sizeof...(T), "Index out of range"); + return detail::getv(t); +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(packed_tuple& t) { + static_assert(I < sizeof...(T), "Index out of range"); + return detail::getv(t); +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(packed_tuple&& t) { + static_assert(I < sizeof...(T), "Index out of range"); + return detail::getv(static_cast&&>(t)); +} + +template +CUTE_HOST_DEVICE constexpr +packed_tuple +make_packed_tuple(T const&... t) +{ + return {t...}; +} + +// Returns the position of type X (as a static integer) in the tuple +// type's argument list. X must be unique in the argument list. +template +CUTE_HOST_DEVICE constexpr +auto +find(packed_tuple const& t) noexcept +{ + return detail::findt(t); +} + +} // end namespace cute + +namespace CUTE_STL_NAMESPACE +{ + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +} // end namespace CUTE_STL_NAMESPACE + +#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD +namespace std { + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +} // end namespace std +#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/include/cute/container/tuple.hpp b/include/cute/container/tuple.hpp new file mode 100644 index 0000000000..3123a68d83 --- /dev/null +++ b/include/cute/container/tuple.hpp @@ -0,0 +1,744 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include // cute::true_type, cute::false_type +#include + +#include +#include +#if defined(CUTLASS_USE_PACKED_TUPLE) +# include +#endif + +//#include // Advanced optimizations + +// cute::tuple is like std::tuple, with two differences. +// +// 1. It works on both host and device. +// 2. Its template arguments must be semiregular types. +// +// Semiregular types are default constructible and copyable. +// They include "value types" like int or float, +// but do _not_ include references like int& or float&. +// (See std::tie for an example of a tuple of references.) +// +// If the template arguments of cute::tuple are all empty types (in +// the sense of std::is_empty_v), then the cute::tuple is also an +// empty type. Furthermore, if CUTLASS_USE_PACKED_TUPLE is defined, +// cute::tuple is always a standard-layout type if all of its template +// arguments are standard-layout types. + +namespace cute +{ + +#if defined(CUTLASS_USE_PACKED_TUPLE) + +template +using tuple = packed_tuple; + +#else + +namespace detail +{ + +// This is simplified over the implementations in std::, cuda::std::, and thrust:: by ignoring much of +// the conversion SFINAE, special overloading, and avoiding cvref template types. +// +// Over standard-conforming tuple implementations, this appears to accelerate compilation times by over 3x. + +// EBO stands for "empty base optimization." +// We use this technique to ensure that cute::tuple +// doesn't need to waste space storing any template arguments +// of cute::tuple that have no data (like integral_constant). +// Otherwise, cute::tuple would need to spend at least 1 byte +// for each of its template arguments. +// +// This is one way in which cute::tuple differs from std::tuple. +// Empty types in the template argument list are not even constructed, +// and do not have unique element addresses. In fact, they are not +// even members of the tuple or stored in any way. Calling `get` +// constructs and returns an instance of an empty type on demand. +// +// EBO always "holds" a single value of type T. +// N is like an array index that TupleBase uses +// to access the desired tuple element. +template ::value> +struct EBO; + +template +CUTE_HOST_DEVICE constexpr C findt(EBO const&) +{ return {}; } + +// Specialization for types T that have no data; +// the "static tuple leaf." Valid T here include +// integral_constant, Int, +// and any other semiregular type +// for which std::is_empty_v is true. 
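+// Illustrative sketch of the effect (assuming the usual empty-base
+// optimization on the target compiler):
+//
+// \code
+//   cute::tuple<cute::Int<3>, int> t(cute::Int<3>{}, 7);
+//   auto s = cute::get<0>(t);   // Int<3>{} is synthesized on demand, not stored
+//   int  x = cute::get<1>(t);   // 7, read from the dynamic leaf below
+//   // sizeof(t) is typically sizeof(int): the empty leaf contributes no storage.
+// \endcode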
+template +struct EBO +{ + CUTE_HOST_DEVICE constexpr + EBO() {} + + CUTE_HOST_DEVICE constexpr + EBO(T const&) {} +}; + +template +CUTE_HOST_DEVICE constexpr T getv(EBO const&) +{ return {}; } + +// Specialization for types T that are not empty; +// the "dynamic tuple leaf." Valid T here include int, +// any other integral or floating-point type, +// or any semiregular type for which std::is_empty_v is false. +template +struct EBO +{ + CUTE_HOST_DEVICE constexpr + EBO() : t_{} {} + + CUTE_HOST_DEVICE constexpr + EBO(T const& t) : t_{t} {} + + T t_; +}; + +template +CUTE_HOST_DEVICE constexpr T const& getv(EBO const& x) +{ return x.t_; } + +template +CUTE_HOST_DEVICE constexpr T& getv(EBO& x) +{ return x.t_; } + +template +CUTE_HOST_DEVICE constexpr T&& getv(EBO&& x) +{ return cute::move(x.t_); } + +template +struct TupleBase; + +// Base class of cute::tuple binds each element to an index +// by inheriting from EBO for each (i, t) in (I..., T...). +// The storage (for nonempty t) lives in the base classes. +template +struct TupleBase, T...> + : EBO... +{ + CUTE_HOST_DEVICE constexpr + TupleBase() {} + + CUTE_HOST_DEVICE constexpr + TupleBase(T const&... t) : EBO(t)... {} +}; + +} // end namespace detail + +// Attempting to use the following commented-out alias +// in the declaration of `struct tuple` causes MSVC 2022 build errors. +// +//template +//using TupleBase = detail::TupleBase, T...>; + +// This is the actual cute::tuple class. +// The storage (if any) lives in TupleBase's EBO base classes. +// +// Inheriting from the above alias TupleBase +// causes MSVC 2022 build errors when assigning one tuple to another: +// In summary: this is verbose as a work-around for MSVC build errors. +template +struct tuple : detail::TupleBase, T...> +{ + CUTE_HOST_DEVICE constexpr + tuple() {} + + CUTE_HOST_DEVICE constexpr + tuple(T const&... t) : detail::TupleBase, T...>(t...) {} +}; + +template <> +struct tuple<> +{}; + +// +// get for cute::tuple (just like std::get for std::tuple) +// + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(tuple const& t) noexcept +{ + static_assert(I < sizeof...(T), "Index out of range"); + return detail::getv(t); +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(tuple& t) noexcept +{ + static_assert(I < sizeof...(T), "Index out of range"); + return detail::getv(t); +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(tuple&& t) noexcept +{ + static_assert(I < sizeof...(T), "Index out of range"); + return detail::getv(static_cast&&>(t)); +} + +// +// find a type X within a cute::tuple +// Requires X to be unique in tuple +// Returns a static integer +// + +template +CUTE_HOST_DEVICE constexpr +auto +find(tuple const& t) noexcept +{ + return detail::findt(t); +} + +#endif // CUTLASS_USE_PACKED_TUPLE + +// +// Custom is_tuple trait simply checks the existence of tuple_size +// and assumes std::get(.), std::tuple_element +// +namespace detail { + +template +auto has_tuple_size( T*) -> bool_constant<(0 <= tuple_size::value)>; +auto has_tuple_size(...) -> false_type; + +} // end namespace detail + +template +struct is_tuple : decltype(detail::has_tuple_size((T*)0)) {}; + +template +constexpr bool is_tuple_v = cute::is_tuple::value; + +// +// make_tuple (value-based implementation) +// + +template +CUTE_HOST_DEVICE constexpr +tuple +make_tuple(T const&... t) +{ + return {t...}; +} + +// +// tuple_cat concatenates multiple cute::tuple into a single cute::tuple, +// just like std::tuple_cat for std::tuple. 
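+// For instance (illustrative only):
+//
+// \code
+//   auto a = cute::make_tuple(cute::Int<2>{}, 3);
+//   auto b = cute::make_tuple(4, cute::Int<5>{});
+//   auto c = cute::tuple_cat(a, b);   // c is (_2, 3, 4, _5)
+//   int  x = cute::get<2>(c);         // 4
+// \endcode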
+// + +#if 0 +// Original implementation + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, + index_sequence, index_sequence) +{ + return cute::make_tuple(get(t0)..., get(t1)...); +} + +} // end namespace detail + +CUTE_HOST_DEVICE constexpr +tuple<> +tuple_cat() +{ + return {}; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +Tuple const& +tuple_cat(Tuple const& t) +{ + return t; +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1) +{ + return detail::tuple_cat(t0, t1, + make_index_sequence::value>{}, + make_index_sequence::value>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, Ts const&... ts) +{ + return cute::tuple_cat(cute::tuple_cat(t0,t1),t2,ts...); +} +#endif + +#if 1 +// Extended implementation + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, + index_sequence, index_sequence) +{ + return cute::make_tuple(get(t0)..., get(t1)...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, + index_sequence, index_sequence, index_sequence) +{ + return cute::make_tuple(get(t0)..., get(t1)..., get(t2)...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, + index_sequence, index_sequence, index_sequence, index_sequence) +{ + return cute::make_tuple(get(t0)..., get(t1)..., get(t2)..., get(t3)...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4, + index_sequence, index_sequence, index_sequence, index_sequence, index_sequence) +{ + return cute::make_tuple(get(t0)..., get(t1)..., get(t2)..., get(t3)..., get(t4)...); +} + +template +struct tuple_cat_static; + +template +struct tuple_cat_static, tuple> { + using type = tuple; +}; + +} // end namespace detail + +CUTE_HOST_DEVICE constexpr +tuple<> +tuple_cat() +{ + return {}; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +Tuple const& +tuple_cat(Tuple const& t) +{ + return t; +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1) +{ + if constexpr (is_static::value && is_static::value && + is_tuple::value && is_tuple::value) { + return typename detail::tuple_cat_static::type{}; + } else { + return detail::tuple_cat(t0, t1, + make_index_sequence::value>{}, + make_index_sequence::value>{}); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2) +{ + return detail::tuple_cat(t0, t1, t2, + make_index_sequence::value>{}, + make_index_sequence::value>{}, + make_index_sequence::value>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3) +{ + return detail::tuple_cat(t0, t1, t2, t3, + make_index_sequence::value>{}, + make_index_sequence::value>{}, + make_index_sequence::value>{}, + make_index_sequence::value>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& t4) +{ + return detail::tuple_cat(t0, t1, t2, t3, t4, + make_index_sequence::value>{}, + make_index_sequence::value>{}, + make_index_sequence::value>{}, + make_index_sequence::value>{}, + make_index_sequence::value>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, T2 const& t2, T3 const& t3, T4 const& 
t4, T5 const& t5, Ts const&... ts) +{ + return cute::tuple_cat(cute::tuple_cat(t0,t1,t2,t3,t4), cute::tuple_cat(t5, ts...)); +} +#endif + +#if 0 +// Outer-Inner indexing trick to concat all tuples at once + +namespace detail { + +template +struct tuple_cat_helper +{ + static constexpr cute::array ns = {Ns...}; + + static constexpr size_t total_size() { + size_t sum = 0; + for (size_t n : ns) sum += n; + return sum; + } + static constexpr size_t total_size_ = total_size(); + + static constexpr auto values() { + cute::array outer_inner = {}; + + size_t idx = 0; + for (size_t i = 0; i < ns.size(); ++i) { + for (size_t j = 0; j < ns[i]; ++j, ++idx) { + outer_inner[idx][0] = i; + outer_inner[idx][1] = j; + } + } + return outer_inner; + } + static constexpr auto outer_inner_ = values(); + + using total_sequence = make_index_sequence; +}; + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(Tuple const& t, index_sequence) +{ + return cute::make_tuple(get(get(t))...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1, + index_sequence, index_sequence) +{ + return cute::make_tuple(get(t0)..., get(t1)...); +} + +} // end namespace detail + +CUTE_HOST_DEVICE constexpr +tuple<> +tuple_cat() +{ + return {}; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +Tuple const& +tuple_cat(Tuple const& t) +{ + return t; +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(T0 const& t0, T1 const& t1) +{ + return detail::tuple_cat(t0, t1, + make_index_sequence::value>{}, + make_index_sequence::value>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tuple_cat(Tuples const&... ts) +{ + using Helper = detail::tuple_cat_helper::value...>; + return detail::tuple_cat(cute::make_tuple(ts...), typename Helper::total_sequence{}); +} +#endif + +// +// Equality operators +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +equal_impl(TupleA const& a, TupleB const& b, index_sequence) +{ + return (cute::true_type{} && ... && (get(a) == get(b))); +} + +} // end namespace detail + +template ::value && is_tuple::value)> +CUTE_HOST_DEVICE constexpr +auto +operator==(TupleT const& t, TupleU const& u) +{ + if constexpr (tuple_size::value == tuple_size::value) { + return detail::equal_impl(t, u, make_index_sequence::value>{}); + } else { + return cute::false_type{}; + } + + CUTE_GCC_UNREACHABLE; +} + +template ::value ^ is_tuple::value)> +CUTE_HOST_DEVICE constexpr +auto +operator==(TupleT const& t, TupleU const& u) +{ + return cute::false_type{}; +} + +template ::value && is_tuple::value)> +CUTE_HOST_DEVICE constexpr +auto +operator!=(TupleT const& t, TupleU const& u) +{ + return !(t == u); +} + +template ::value ^ is_tuple::value)> +CUTE_HOST_DEVICE constexpr +auto +operator!=(TupleT const& t, TupleU const& u) +{ + return cute::true_type{}; +} + +// +// Comparison operators +// + +// +// There are many ways to compare tuple of elements and because CuTe is built +// on parameterizing layouts of coordinates, some comparisons are appropriate +// only in certain cases. +// -- lexicographical comparison [reverse, reflected, revref] +// -- colexicographical comparison [reverse, reflected, revref] +// -- element-wise comparison [any,all] +// This can be very confusing. To avoid errors in selecting the appropriate +// comparison, op<|op<=|op>|op>= are *not* implemented for cute::tuple. +// +// That said, see int_tuple for more explicitly named common comparison ops. 
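+// Equality, by contrast, is defined above and can yield a static result when
+// the comparison is decidable at compile time. An illustrative sketch:
+//
+// \code
+//   auto t  = cute::make_tuple(cute::Int<2>{}, cute::Int<3>{});
+//   auto u  = cute::make_tuple(cute::Int<2>{}, 3);
+//   auto v  = cute::make_tuple(cute::Int<2>{});
+//   auto b0 = (t == u);   // dynamic result: the 3 in u is a runtime value
+//   auto b1 = (t == v);   // cute::false_type{}: mismatched ranks compare unequal
+// \endcode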
+// + +// +// Display utilities +// + +namespace detail { + +template +CUTE_HOST_DEVICE void print_tuple(Tuple const& t, index_sequence, char s = '(', char e = ')') +{ + using cute::print; + if (sizeof...(Is) == 0) { + print(s); + } else { + ((void(print(Is == 0 ? s : ',')), void(print(get(t)))), ...); + } + print(e); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& print_tuple_os(std::ostream& os, Tuple const& t, index_sequence, char s = '(', char e = ')') +{ + if (sizeof...(Is) == 0) { + os << s; + } else { + (void(os << (Is == 0 ? s : ',') << get(t)), ...); + } + return os << e; +} +#endif // !defined(__CUDACC_RTC__) + +} // end namespace detail + +template ::value)> +CUTE_HOST_DEVICE void print(Tuple const& t) +{ + return detail::print_tuple(t, make_index_sequence::value>{}); +} + +#if !defined(__CUDACC_RTC__) +template ::value)> +CUTE_HOST std::ostream& operator<<(std::ostream& os, Tuple const& t) +{ + return detail::print_tuple_os(os, t, make_index_sequence::value>{}); +} +#endif // !defined(__CUDACC_RTC__) + +} // end namespace cute + +#if ! defined(CUTLASS_USE_PACKED_TUPLE) + +namespace CUTE_STL_NAMESPACE +{ + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +} // end namespace CUTE_STL_NAMESPACE + +// +// std compatibility +// + +#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD +namespace std +{ + +#if defined(__CUDACC_RTC__) +template +struct tuple_size; + +template +struct tuple_element; +#endif + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +} // end namespace std +#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD + +#endif // CUTLASS_USE_PACKED_TUPLE diff --git a/include/cute/container/type_list.hpp b/include/cute/container/type_list.hpp new file mode 100644 index 0000000000..a15f2c1c15 --- /dev/null +++ b/include/cute/container/type_list.hpp @@ -0,0 +1,124 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE, CUTE_STL_NAMESPACE + +namespace cute +{ + +template +struct type_list {}; + +// get for type_list +// requires tuple_element_t> to have std::is_default_constructible +template +CUTE_HOST_DEVICE constexpr +CUTE_STL_NAMESPACE::tuple_element_t> +get(type_list const& t) noexcept { + return {}; +} + +} // end namespace cute + +// +// Specialize tuple-related functionality for cute::type_list +// + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +namespace CUTE_STL_NAMESPACE +{ + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> +{ + using type = typename CUTE_STL_NAMESPACE::tuple_element>::type; +}; + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> +{ + using type = typename CUTE_STL_NAMESPACE::tuple_element>::type; +}; + +} // end namespace std + +#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD +namespace std +{ + +#if defined(__CUDACC_RTC__) +template +struct tuple_size; + +template +struct tuple_element; +#endif + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> +{ + using type = typename CUTE_STL_NAMESPACE::tuple_element>::type; +}; + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> +{ + using type = typename CUTE_STL_NAMESPACE::tuple_element>::type; +}; + +} // end namespace std +#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/include/cute/int_tuple.hpp b/include/cute/int_tuple.hpp new file mode 100644 index 0000000000..95d06bbdd7 --- /dev/null +++ b/include/cute/int_tuple.hpp @@ -0,0 +1,864 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::array +#include // cute::is_tuple +#include // cute::Int +#include // cute::transform + +/** IntTuple is an integer or a tuple of IntTuples. + * This file holds utilities for working with IntTuples, + * but does not hold a concrete concept or class of IntTuple. + */ + +namespace cute +{ + +// Implementation of get<0>(Integral). +// Even though is_tuple is false and tuple_size doesn't compile, +// CuTe defines rank(Integral) as 1, so it's useful for get<0>(Integral) to return its input +template >::value)> +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(T&& t) noexcept +{ + static_assert(I == 0, "Index out of range"); + return static_cast(t); +} + +// Custom recursive get for anything that implements get(.) (for a single integer I). +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +get(T&& t) noexcept +{ + return get(get(static_cast(t))); +} + +// +// rank +// + +template +CUTE_HOST_DEVICE constexpr +auto +rank(IntTuple const& t) +{ + if constexpr (sizeof...(Is) == 0) { + if constexpr (is_tuple::value) { + return Int::value>{}; + } else { + return Int<1>{}; + } + } else { + return rank(get(t)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +using rank_t = decltype(rank(declval())); + +template +static constexpr auto rank_v = rank_t::value; + +// +// shape +// + +template +CUTE_HOST_DEVICE constexpr +auto +shape(IntTuple const& s) +{ + if constexpr (is_tuple::value) { + return transform(s, [](auto const& a) { return shape(a); }); + } else { + return s; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +shape(IntTuple const& s) +{ + if constexpr (is_tuple::value) { + return shape(get(s)); + } else { + return get(shape(s)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// max +// + +template +CUTE_HOST_DEVICE constexpr +auto +max(T0 const& t0, Ts const&... ts) +{ + if constexpr (is_tuple::value) { + return cute::max(cute::apply(t0, [](auto const&... a){ return cute::max(a...); }), ts...); + } else if constexpr (sizeof...(Ts) == 0) { + return t0; + } else { + return cute::max(t0, cute::max(ts...)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// min +// + +template +CUTE_HOST_DEVICE constexpr +auto +min(T0 const& t0, Ts const&... ts) +{ + if constexpr (is_tuple::value) { + return cute::min(cute::apply(t0, [](auto const&... 
a){ return cute::min(a...); }), ts...); + } else if constexpr (sizeof...(Ts) == 0) { + return t0; + } else { + return cute::min(t0, cute::min(ts...)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// gcd +// + +template +CUTE_HOST_DEVICE constexpr +auto +gcd(T0 const& t0, Ts const&... ts) +{ + if constexpr (is_tuple::value) { + return cute::gcd(cute::apply(t0, [](auto const&... a){ return cute::gcd(a...); }), ts...); + } else if constexpr (sizeof...(Ts) == 0) { + return t0; + } else { + return cute::gcd(t0, cute::gcd(ts...)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// depth +// + +template +CUTE_HOST_DEVICE constexpr +auto +depth(IntTuple const& t) +{ + if constexpr (sizeof...(Is) == 0) { + if constexpr (is_tuple::value) { + return Int<1>{} + cute::apply(t, [](auto const&... v){ return cute::max(depth(v)...); }); + } else { + return Int<0>{}; + } + } else { + return depth(get(t)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +using depth_t = decltype(depth(declval())); + +template +static constexpr auto depth_v = depth_t::value; + +// +// product +// + +// Implementation of product as a function object +struct Product +{ + template + CUTE_HOST_DEVICE constexpr + auto + operator()(IntTuple const& a) const + { + if constexpr (is_tuple::value) { + if constexpr (tuple_size::value == 0) { + return Int<1>{}; + } else { + return cute::transform_apply(a, Product{}, multiplies_unary_lfold{}); + } + } else if constexpr (cute::is_integral::value) { + return a; + } + + CUTE_GCC_UNREACHABLE; + } +}; +// Callable product function object +CUTE_INLINE_CONSTANT Product product; + +// Return a rank(t) tuple @a result such that get(@a result) = product(get(@a t)) +template +CUTE_HOST_DEVICE constexpr +auto +product_each(Tuple const& t) +{ + return transform(wrap(t), product); +} + +// Take the product of Tuple at the leaves of TupleG +template +CUTE_HOST_DEVICE constexpr +auto +product_like(Tuple const& tuple, TupleG const& guide) +{ + return transform_leaf(guide, tuple, [](auto const& g, auto const& t) { return product(t); }); +} + +// Return the product of elements in a mode +template +CUTE_HOST_DEVICE constexpr +auto +size(IntTuple const& a) +{ + if constexpr (sizeof...(Is) == 0) { + return product(a); + } else { + return size(get(a)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +static constexpr auto size_v = decltype(size(declval()))::value; + +// +// sum +// + +template +CUTE_HOST_DEVICE constexpr +auto +sum(IntTuple const& a) +{ + if constexpr (is_tuple::value) { + return cute::apply(a, [](auto const&... v){ return (Int<0>{} + ... + sum(v)); }); + } else { + return a; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// inner_product +// + +template +CUTE_HOST_DEVICE constexpr +auto +inner_product(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); + return transform_apply(a, b, [](auto const& x, auto const& y) { return inner_product(x,y); }, + [](auto const&... v) { return (Int<0>{} + ... 
+ v); }); + } else { + return a * b; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// ceil_div +// + +template +CUTE_HOST_DEVICE constexpr +auto +ceil_div(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value) { + if constexpr (is_tuple::value) { // tuple tuple + static_assert(tuple_size::value >= tuple_size::value, "Mismatched ranks"); + constexpr int R = tuple_size::value; // Missing ranks in TupleB are implicitly 1 + return transform(a, append(b,Int<1>{}), [](auto const& x, auto const& y) { return ceil_div(x,y); }); + } else { // tuple int + auto const [result, rest] = fold(a, cute::make_tuple(cute::make_tuple(), b), + [] (auto const& init, auto const& ai) { + return cute::make_tuple(append(get<0>(init), ceil_div(ai, get<1>(init))), ceil_div(get<1>(init), ai)); + }); + return result; + } + } else + if constexpr (is_tuple::value) { // int tuple + return ceil_div(a, product(b)); + } else { + return (a + b - Int<1>{}) / b; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// round_up +// Round @a a up to the nearest multiple of @a b. +// For negative numbers, rounds away from zero. +// + +template +CUTE_HOST_DEVICE constexpr +auto +round_up(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + static_assert(tuple_size::value >= tuple_size::value, "Mismatched ranks"); + constexpr int R = tuple_size::value; // Missing ranks in TupleB are implicitly 1 + return transform(a, append(b,Int<1>{}), [](auto const& x, auto const& y) { return round_up(x,y); }); + } else { + return ((a + b - Int<1>{}) / b) * b; + } + + CUTE_GCC_UNREACHABLE; +} + +/** Division for Shapes + * Case Tuple Tuple: + * Perform shape_div element-wise + * Case Tuple Int: + * Fold the division of b across each element of a + * Example: shape_div((4,5,6),40) -> shape_div((1,5,6),10) -> shape_div((1,1,6),2) -> (1,1,3) + * Case Int Tuple: + * Return shape_div(a, product(b)) + * Case Int Int: + * Enforce the divisibility condition a % b == 0 || b % a == 0 when possible + * Return a / b with rounding away from 0 (that is, 1 or -1 when a < b) + */ +template +CUTE_HOST_DEVICE constexpr +auto +shape_div(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value) { + if constexpr (is_tuple::value) { // tuple tuple + static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); + return transform(a, b, [](auto const& x, auto const& y) { return shape_div(x,y); }); + } else { // tuple int + auto const [result, rest] = fold(a, cute::make_tuple(cute::make_tuple(), b), + [] (auto const& init, auto const& ai) { + return cute::make_tuple(append(get<0>(init), shape_div(ai, get<1>(init))), shape_div(get<1>(init), ai)); + }); + return result; + } + } else + if constexpr (is_tuple::value) { // int tuple + return shape_div(a, product(b)); + } else + if constexpr (is_static::value && is_static::value) { + static_assert(IntTupleA::value % IntTupleB::value == 0 || IntTupleB::value % IntTupleA::value == 0, "Static shape_div failure"); + return C{}; + } else { // int int + //assert(a % b == 0 || b % a == 0); // Waive dynamic assertion + return a / b != 0 ? 
a / b : signum(a) * signum(b); // Division with rounding away from zero + } + + CUTE_GCC_UNREACHABLE; +} + +/** Minimum for Shapes + */ +template +CUTE_HOST_DEVICE constexpr +auto +shape_min(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value || is_tuple::value) { + static_assert(dependent_false, "Not implemented."); + } else + if constexpr (is_constant<1, IntTupleA>::value || is_constant<1, IntTupleB>::value) { + return Int<1>{}; // _1 is less than all other shapes, preserve static + } else { + return cute::min(a, b); + } + + CUTE_GCC_UNREACHABLE; +} + +/** Return a tuple the same profile as A scaled by corresponding elements in B + */ +template +CUTE_HOST_DEVICE constexpr +auto +elem_scale(A const& a, B const& b) +{ + if constexpr (is_tuple::value) { + return transform(a, b, [](auto const& x, auto const& y) { return elem_scale(x,y); }); + } else { + return a * product(b); + } + + CUTE_GCC_UNREACHABLE; +} + +/** Test if two IntTuple have the same profile (hierarchical rank division) + */ +template +CUTE_HOST_DEVICE constexpr +auto +congruent(IntTupleA const& a, IntTupleB const& b) +{ + return bool_constant::value>{}; +} + +template +using is_congruent = decltype(congruent(declval(), declval())); + +/** Test if two IntTuple have the similar profiles up to Shape A (hierarchical rank division) + * weakly_congruent is a partial order on A and B: A <= B + */ +template +CUTE_HOST_DEVICE constexpr +auto +weakly_congruent(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + if constexpr (tuple_size::value != tuple_size::value) { + return false_type{}; + } else { + return transform_apply(a, b, [](auto const& x, auto const& y) { return weakly_congruent(x,y); }, + [](auto const&... z) { return (true_type{} && ... && z); }); + } + } else if constexpr (is_integral::value) { + return true_type{}; + } else if constexpr (is_integral::value) { + return false_type{}; + } else { + return weakly_congruent(shape(a), shape(b)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +using is_weakly_congruent = decltype(weakly_congruent(declval(), declval())); + +/** Test if Shape A is compatible with Shape B: + * the size of A and B are the same, and + * any coordinate into A can also be used as a coordinate into B + * Equivalently, the size of Shape B is the same as Shape A at each terminal of Shape A. + * compatible is a partial order on A and B: A <= B + */ +template +CUTE_HOST_DEVICE constexpr +auto +compatible(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + if constexpr (tuple_size::value != tuple_size::value) { + return false_type{}; + } else { + return transform_apply(a, b, [](auto const& x, auto const& y) { return compatible(x,y); }, + [](auto const&... z) { return (true_type{} && ... && z); }); + } + } else if constexpr (is_integral::value) { + return a == size(b); + } else if constexpr (is_integral::value) { + return false_type{}; + } else { + return compatible(shape(a), shape(b)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +using is_compatible = decltype(compatible(declval(), declval())); + +/** Test if Shape A is evenly divided by Tiler B + * @returns Static or dynamic boolean + * @post if result is true_type, then + * size(a) == logical_divide(make_layout(shape(a)),b) will always compile + * and result in true_type. 
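+ *
+ * Illustrative sketch (shapes built with cute::make_tuple and cute::Int):
+ * \code
+ *   evenly_divides(make_tuple(Int<8>{}, Int<4>{}), make_tuple(Int<2>{}, Int<2>{}));  // true_type{}
+ *   evenly_divides(make_tuple(Int<8>{}, Int<4>{}), Int<3>{});                        // false_type{}
+ * \endcode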
+ */ +template +CUTE_HOST_DEVICE constexpr +auto +evenly_divides(Shape const& a, Tiler const& b) +{ + if constexpr (is_tuple::value) { + if constexpr (rank_v > rank_v) { + return false_type{}; + } else { + return transform_apply(b, a, [](auto const& x, auto const& y) { return evenly_divides(y,x); }, + [](auto const&... z) { return (true_type{} && ... && z); }); + } + } else { + return size(a) == size(b) * size(ceil_div(shape(a), b)); + } + + CUTE_GCC_UNREACHABLE; +} + +/** Replace the elements of Tuple B that are paired with an Int<0> with an Int<1> + */ +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value) { + return transform(a, b, [](auto const& x, auto const& y) { return filter_zeros(x,y); }); + } else if constexpr (is_constant<0, IntTupleA>::value) { + return repeat_like(b, Int<1>{}); + } else { + return b; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Tuple const& t) +{ + return filter_zeros(t, t); +} + +// +// Converters and constructors with arrays and params +// + +/** Make an IntTuple of rank N from an Indexable array. + * Access elements up to a dynamic index n, then use init (requires compatible types) + * Consider cute::take if all indexing is known to be valid + * \code + * std::vector a = {6,3,4}; + * auto tup = make_int_tuple<5>(a, a.size(), 0) // (6,3,4,0,0) + * \endcode + */ +template +CUTE_HOST_DEVICE constexpr +auto +make_int_tuple(Indexable const& t, int n, T const& init) +{ + static_assert(N > 0); + if constexpr (N == 1) { + return 0 < n ? t[0] : init; + } else { + return transform(make_seq{}, [&](auto i) { return i < n ? t[i] : init; }); + } + + CUTE_GCC_UNREACHABLE; +} + +/** Fill the dynamic values of a Tuple with values from another Tuple + * \code + * auto params = make_tuple(6,3,4); + * cute::tuple, cute::tuple>, int, Int<2>> result; + * fill_int_tuple_from(result, params); // (_1,(6,3,_3),4,_2) + * \endcode + */ +template +CUTE_HOST_DEVICE constexpr +auto +fill_int_tuple_from(Tuple& result, TupleV const& vals) +{ + return fold(result, vals, [](auto const& init, auto&& r) { + if constexpr (is_static>::value) { // Skip static elements of result + return init; + } else if constexpr (is_tuple>::value) { // Recurse into tuples + return fill_int_tuple_from(r, init); + } else { // Assign and consume arg + static_assert(tuple_size>::value > 0, "Not enough values to fill with!"); + r = get<0>(init); + return remove<0>(init); + } + + CUTE_GCC_UNREACHABLE; + }); +} + +/** Make a "Tuple" by filling in the dynamic values in order from the arguments + * \code + * using result_t = cute::tuple, cute::tuple>, int, Int<2>>; + * auto result = make_int_tuple_from(6,3,4); // (_1,(6,3,_3),4,_2) + * \endcode + */ +template +CUTE_HOST_DEVICE constexpr +Tuple +make_int_tuple_from(Ts const&... 
ts) +{ + Tuple result = Tuple{}; + fill_int_tuple_from(result, cute::make_tuple(ts...)); + return result; +} + +/** Convert a tuple to a flat homogeneous array of type T + * \code + * auto tup = cute::make_tuple(Int<1>{}, cute::make_tuple(6,3,Int<3>{}),4,Int<2>{}); + * cute::array result = to_array(tup); // [1,6,3,3,4,2] + * \endcode + */ +template +CUTE_HOST_DEVICE constexpr +auto +to_array(IntTuple const& t) +{ + auto flat_t = flatten_to_tuple(t); + constexpr int N = tuple_size::value; + cute::array result; + for_each(make_seq{}, [&] (auto i) { result[i] = get(flat_t); }); + return result; +} + +// +// Comparison operators +// + +// +// There are many ways to compare tuple of elements and because CuTe is built +// on parameterizing layouts of coordinates, some comparisons are appropriate +// only in certain cases. +// -- lexicographical comparison [reverse, reflected, revref] : Correct for coords in RowMajor Layout +// -- colexicographical comparison [reverse, reflected, revref] : Correct for coords in ColMajor Layout +// -- element-wise comparison [any,all] : +// This can be very confusing. To avoid errors in selecting the appropriate +// comparison, op<|op<=|op>|op>= are *not* implemented for cute::tuple. +// +// When actually desiring to order coordinates, the user should map them to +// their indices within the Layout they came from: +// e.g. layoutX(coordA) < layoutX(coordB) +// That said, we implement the three most common ways to compare tuples below. +// These are implemented with slighly more explicit names than op<. +// + +template +CUTE_HOST_DEVICE constexpr +auto +lex_less(IntTupleA const& a, IntTupleB const& b); + +template +CUTE_HOST_DEVICE constexpr +auto +colex_less(IntTupleA const& a, IntTupleB const& b); + +template +CUTE_HOST_DEVICE constexpr +auto +elem_less(IntTupleA const& a, IntTupleB const& b); + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +lex_less_impl(TupleA const& a, TupleB const& b) +{ + if constexpr (I == tuple_size::value) { + return cute::false_type{}; // Terminal: TupleB is exhausted + } else if constexpr (I == tuple_size::value) { + return cute::true_type{}; // Terminal: TupleA is exhausted, TupleB is not exhausted + } else { + return lex_less(get(a), get(b)) || (get(a) == get(b) && lex_less_impl(a,b)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +colex_less_impl(TupleA const& a, TupleB const& b) +{ + if constexpr (I == tuple_size::value) { + return cute::false_type{}; // Terminal: TupleB is exhausted + } else if constexpr (I == tuple_size::value) { + return cute::true_type{}; // Terminal: TupleA is exhausted, TupleB is not exhausted + } else { + constexpr size_t A = tuple_size::value - 1 - I; + constexpr size_t B = tuple_size::value - 1 - I; + return colex_less(get(a), get(b)) || (get(a) == get(b) && colex_less_impl(a,b)); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +elem_less_impl(TupleA const& a, TupleB const& b) +{ + if constexpr (I == tuple_size::value) { + return cute::true_type{}; // Terminal: TupleA is exhausted + } else if constexpr (I == tuple_size::value) { + return cute::false_type{}; // Terminal: TupleA is not exhausted, TupleB is exhausted + } else { + return elem_less(get(a), get(b)) && elem_less_impl(a,b); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +// Lexicographical comparison + +template +CUTE_HOST_DEVICE constexpr +auto +lex_less(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) 
{ + return detail::lex_less_impl<0>(a, b); + } else { + return a < b; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +lex_leq(T const& t, U const& u) { + return !lex_less(u, t); +} + +template +CUTE_HOST_DEVICE constexpr +auto +lex_gtr(T const& t, U const& u) { + return lex_less(u, t); +} + +template +CUTE_HOST_DEVICE constexpr +auto +lex_geq(T const& t, U const& u) { + return !lex_less(t, u); +} + +// Colexicographical comparison + +template +CUTE_HOST_DEVICE constexpr +auto +colex_less(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + return detail::colex_less_impl<0>(a, b); + } else { + return a < b; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +colex_leq(T const& t, U const& u) { + return !colex_less(u, t); +} + +template +CUTE_HOST_DEVICE constexpr +auto +colex_gtr(T const& t, U const& u) { + return colex_less(u, t); +} + +template +CUTE_HOST_DEVICE constexpr +auto +colex_geq(T const& t, U const& u) { + return !colex_less(t, u); +} + +// Elementwise [all] comparison + +template +CUTE_HOST_DEVICE constexpr +auto +elem_less(IntTupleA const& a, IntTupleB const& b) +{ + if constexpr (is_tuple::value && is_tuple::value) { + return detail::elem_less_impl<0>(a, b); + } else { + return a < b; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +elem_leq(T const& t, U const& u) { + return !elem_less(u, t); +} + +template +CUTE_HOST_DEVICE constexpr +auto +elem_gtr(T const& t, U const& u) { + return elem_less(u, t); +} + +template +CUTE_HOST_DEVICE constexpr +auto +elem_geq(T const& t, U const& u) { + return !elem_less(t, u); +} + +} // end namespace cute diff --git a/include/cute/layout.hpp b/include/cute/layout.hpp new file mode 100644 index 0000000000..26195a4782 --- /dev/null +++ b/include/cute/layout.hpp @@ -0,0 +1,2058 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include +#include +#include +#include +#include +#include // cute::sizeof_bits + +namespace cute +{ + +// Aliases + +template +using Shape = cute::tuple; + +template +using Stride = cute::tuple; + +template +using Step = cute::tuple; + +template +using Coord = cute::tuple; + +template +using Tile = cute::tuple; + +template +CUTE_HOST_DEVICE constexpr +Shape +make_shape(Ts const&... t) { + return {t...}; +} +template +CUTE_HOST_DEVICE constexpr +Stride +make_stride(Ts const&... t) { + return {t...}; +} +template +CUTE_HOST_DEVICE constexpr +Step +make_step(Ts const&... t) { + return {t...}; +} +template +CUTE_HOST_DEVICE constexpr +Coord +make_coord(Ts const&... t) { + return {t...}; +} +template +CUTE_HOST_DEVICE constexpr +Tile +make_tile(Ts const&... t) +{ + return {t...}; +} + +// +// Layout +// + +template > +struct Layout + : private cute::tuple // EBO for static layouts +{ + // Expensive in compilation time... + //static_assert(is_congruent::value, "Shape and Stride must be congruent"); + + // NOTE: This defaults static Shapes/Strides correctly, but not dynamic + CUTE_HOST_DEVICE constexpr + Layout(Shape const& shape = {}, Stride const& stride = {}) + : cute::tuple(shape, stride) + {} + + // + // Accessors + // + + static constexpr int rank = rank_v; + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout() { + return *this; + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout() const { + return *this; + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + shape() { + return get<0,I...>(static_cast&>(*this)); + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + shape() const { + return get<0,I...>(static_cast const&>(*this)); + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + stride() { + return get<1,I...>(static_cast&>(*this)); + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + stride() const { + return get<1,I...>(static_cast const&>(*this)); + } + + // + // Mappings + // + + // Map a logical coordinate to a linear index (Coord has no Underscore slice operators) + // OR + // Slice the layout and return the sublayout (Coord has an Underscore slice op) + template + CUTE_HOST_DEVICE constexpr + auto + operator()(Coord const& coord) const { + if constexpr (has_underscore::value) { + return slice(coord, *this); + } else { + return crd2idx(coord, shape(), stride()); + } + + CUTE_GCC_UNREACHABLE; + } + + // Convenience function for multi-dimensional coordinates + template + CUTE_HOST_DEVICE constexpr + auto + operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... 
cs) const { + return operator()(make_coord(c0,c1,cs...)); + } + + // + // Compose + // + + template + CUTE_HOST_DEVICE constexpr + auto + compose(OtherLayout const& other) const { + return composition(*this, other); + } + + template + CUTE_HOST_DEVICE constexpr + auto + compose(Layouts const&... layouts) const { + return composition(*this, make_tile(layouts...)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + with_shape(OtherShape const& shape) const { + return composition(*this, make_layout(shape)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + with_shape(Shapes const&... shapes) const { + return composition(*this, make_layout(make_shape(shapes...))); + } + + // + // Tile + // + + template + CUTE_HOST_DEVICE constexpr + auto + tile(OtherLayout const& other) const { + return tiled_divide(*this, other); + } + + template + CUTE_HOST_DEVICE constexpr + auto + tile(Layouts const&... layouts) const { + return tiled_divide(*this, make_tile(layouts...)); + } + + // + // Utility + // + + // + // Index to Coordinate + // + + // NOTE: Only valid for compact layouts + + // Return the (hierarchical) ND logical coordinate corresponding to the linear index + // @post crd2idx(@a result, shape(), stride()) == idx + // @post congruent(@a result, shape()) + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_hier_coord(IInt const& idx) const { + return cute::idx2crd(idx, shape(), stride()); + } + + // Return the (flat) ND logical coordinate corresponding to the linear index + // @post crd2idx(@a result, shape(), stride()) == idx + // @post rank(@a result) == rank(shape()) && depth(@a result) == 1 + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_flat_coord(IInt const& idx) const { + return cute::crd2crd(this->get_hier_coord(idx), shape(), repeat(Int<1>{})); + } + + // Return the generalized column-major 1D logical coordinate corresponding to the linear index + // @post crd2idx(@a result, shape(), stride()) == idx + // @post is_integral::value + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_1d_coord(IInt const& idx) const { + return cute::crd2idx(this->get_hier_coord(idx), shape()); + } + + // + // Coordinate to Coordinate + // + +#if 0 + // Return the (hierarchical) ND logical coordinate corresponding to the linear index + // @post congruent(@a result, shape()) + template + CUTE_HOST_DEVICE constexpr + auto + crd_2_hier_coord(Coord const& crd) const { + return cute::crd2crd(crd, shape(), shape()); + } + + // Return the (flat) ND logical coordinate corresponding to the linear index + // @post rank(@a result) == rank(shape()) && depth(@a result) == 1 + template + CUTE_HOST_DEVICE constexpr + auto + crd_2_flat_coord(Coord const& crd) const { + return cute::crd2crd(crd, shape(), product_each(shape())); + } + + // Return the generalized column-major 1D logical coordinate corresponding to the linear index + // @post is_integral::value + template + CUTE_HOST_DEVICE constexpr + auto + crd_2_1d_coord(Coord const& crd) const { + //return cute::crd2crd(crd, shape(), product(shape())); + return cute::crd2idx(crd, shape()); + } +#endif +}; + +// Equality, return a static or dynamic boolean +template +CUTE_HOST_DEVICE constexpr +auto +operator==(Layout const& layoutA, Layout const& layoutB) +{ + return layoutA.shape() == layoutB.shape() && layoutA.stride() == layoutB.stride(); +} + +template +struct is_layout : false_type {}; +template +struct is_layout> : true_type {}; + +// +// Layout construction +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Shape 
const& shape, Stride const& stride) +{ + static_assert(is_tuple::value || is_integral::value); + static_assert(is_tuple::value || is_integral::value); + return Layout(shape, stride); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Shape const& shape) +{ + static_assert(is_tuple::value || is_integral::value); + return make_layout(shape, compact_major(shape)); +} + +// +// Convenience tags for common layouts +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Shape const& shape, LayoutLeft) +{ + return make_layout(shape, compact_major(shape)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Shape const& shape, LayoutRight) +{ + return make_layout(shape, compact_major(shape)); +} + +// +// Construct a layout from multiple layouts by concatenation +// + +// One argument overload +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Layout const& layout0) +{ + return make_layout(make_shape (layout0.shape() ), + make_stride(layout0.stride())); +} + +// Two argument overload +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Layout const& layout0, + Layout const& layout1) +{ + return make_layout(make_shape (layout0.shape() , layout1.shape() ), + make_stride(layout0.stride(), layout1.stride())); +} + +// Var argument overload +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Layout const& layout0, + Layout const& layout1, + Layout const&... layouts) +{ + return make_layout(make_shape (layout0.shape() , layout1.shape() , layouts.shape()... ), + make_stride(layout0.stride(), layout1.stride(), layouts.stride()...)); +} + +// +// Advanced Layout constructions +// + +// Make a compact layout with shape @a shape and strides following the order induced by @a order. +// Dynamic values in @a order are ignored, considered large, and considered ordered from left to right. +// Example: +// make_ordered_layout(Shape<_2,_2,_2,_2>{}, Step<_0,_2,_3,_1>{}) +// -> (_2,_2,_2,_2):(_1,_4,_8,_2) +// make_ordered_layout(make_shape(2,3,4,5), make_step(Int<2>{}, 67, 42, Int<50>{})) +// -> (2,3,4,5):(_1,10,30,2) +template +CUTE_HOST_DEVICE constexpr +auto +make_ordered_layout(Shape const& shape, Order const& order) +{ + return make_layout(shape, compact_order(shape, order)); +} + +// Make a compact layout with the same shape as @a layout +// and strides following the order induced by @a layout.stride(). +// Static-0 strides in the input @a layout are preserved in the output. +// Example: +// make_layout_like(Layout, Stride<_0,_2,_4,_1>>{}) +// -> (_2,_2,_2,_2):(_0,_2,_4,_1) +// make_layout_like(make_layout(make_shape(2,3,4,5), make_stride(Int<0>{},42,Int<1>{},Int<0>{}))) +// -> (2,3,4,5):(_0,4,_1,_0) +template +CUTE_HOST_DEVICE constexpr +auto +make_layout_like(Layout const& layout) +{ + return make_layout(layout.shape(), + compact_order(filter_zeros(layout.stride(), layout.shape()), layout.stride())); +} + +// Make a compact layout with the same shape as @a layout +// and strides following the order induced by @a layout.stride(), +// except mode-0 is always stride-1 and generated column-major. 
+// The 0th mode is commonly used for MMA_Atoms or Copy_Atoms so this +// generates the 0th mode with LayoutLeft (preserving stride-0s) regardless of the reference layout +template +CUTE_HOST_DEVICE constexpr +auto +make_fragment_like(Layout const& layout) +{ + constexpr int R = Layout::rank; + if constexpr (R > 1 && is_static::value) { + return tiled_product(make_layout(get<0>(layout.shape()), + compact_major(filter_zeros(get<0>(layout.stride()), get<0>(layout.shape())))), + make_ordered_layout(take<1,R>(layout.shape()), take<1,R>(layout.stride()))); + } else { + return make_layout(layout.shape()); + } + + CUTE_GCC_UNREACHABLE; +} + +template ::value || is_integral::value)> +CUTE_HOST_DEVICE constexpr +auto +make_fragment_like(Shape const& shape) +{ + return make_layout(shape); +} + +// +// Make an identity layout that maps a coordinate to itself +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_identity_layout(Shape const& shape) +{ + return make_layout(shape, make_basis_like(shape)); +} + +// +// Operations to manipulate Layouts like a tuple of pairs +// + +// Return the Is...th sublayout. +// For Is... = , equivalent to get(...get(get(layout))) +template +CUTE_HOST_DEVICE constexpr +auto +get(Layout const& layout) +{ + return make_layout(get(layout.shape()), + get(layout.stride())); +} + +// Return a new layout with only the modes in the range [B,E) +template +CUTE_HOST_DEVICE constexpr +auto +take(Layout const& layout) +{ + static_assert(B < E, "take: empty range error"); + static_assert(0 <= B && E <= Layout::rank, "take: range out of bounds"); + return make_layout(take(layout.shape()), + take(layout.stride())); +} + +// Return a new layout with only the modes Is... = +template +CUTE_HOST_DEVICE constexpr +auto +select(Layout const& layout) +{ + return make_layout(select(layout.shape()), + select(layout.stride())); +} + +// Return a layout with depth at most 1 +template +CUTE_HOST_DEVICE constexpr +auto +flatten(Layout const& layout) +{ + return make_layout(flatten(layout.shape()), + flatten(layout.stride())); +} + +// Return a layout whose profile is congruent to TargetProfile +// @pre Input layout is flat, flatten(@a layout) == @a layout +// @pre Input layout can be folded to profile, rank(@a layout) == rank(flatten(@a target_profile)) +// @post congruent(@a result, @a target_profile) +template +CUTE_HOST_DEVICE constexpr +auto +unflatten(Layout const& layout, TargetProfile const& target_profile) +{ + return make_layout(unflatten(layout.shape(), target_profile), + unflatten(layout.stride(), target_profile)); +} + +// +// Utilities +// + +// Return the sublayout of mode I... 
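+// Illustrative sketch:
+// \code
+//   auto L  = make_layout(make_shape(4, 8));   // (4,8):(_1,4), compact column-major
+//   auto m1 = layout<1>(L);                    // 8:4, the second mode as its own Layout
+//   auto L0 = layout(L);                       // with no mode indices, returns L itself
+// \endcode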
+template +CUTE_HOST_DEVICE constexpr +decltype(auto) +layout(Layout const& layout) +{ + if constexpr (sizeof...(Is) == 0) { + return layout; + } else { + return get(layout); + } + + CUTE_GCC_UNREACHABLE; +} + +// Return the shape of a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +shape(Layout& layout) +{ + return layout.template shape(); +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +shape(Layout const& layout) +{ + return layout.template shape(); +} + +// Return the stride of a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +stride(Layout& layout) +{ + return layout.template stride(); +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +stride(Layout const& layout) +{ + return layout.template stride(); +} + +// Return the number of elements in a mode +template +CUTE_HOST_DEVICE constexpr +auto +size(Layout const& layout) +{ + return size(shape(layout)); +} + +// Return the number of modes +template +CUTE_HOST_DEVICE constexpr +auto +rank(Layout const& layout) +{ + return rank(shape(layout)); +} + +// Return the depth of the layout +template +CUTE_HOST_DEVICE constexpr +auto +depth(Layout const& layout) +{ + return depth(shape(layout)); +} + +// Return the codomain shape of a mode +// @post size(coshape(@a a)) == cosize(@a a) +// @return C Coordinate with smallest elements such that +// @a elem_less(sub_layout(c), C) for all c < size(@a sub_layout) +// where sub_layout = get(layout). +template +CUTE_HOST_DEVICE constexpr +auto +coshape(Layout const& layout) +{ + // Protect against negative strides + auto abs_sub_layout = make_layout(shape(layout), + transform_leaf(stride(layout), abs_fn{})); + auto co_coord = as_arithmetic_tuple(abs_sub_layout(size(abs_sub_layout) - Int<1>{})); + return co_coord + repeat_like(co_coord, Int<1>{}); +} + +// Return the codomain size of a mode +// @return M smallest integer such that +// @a sub_layout(c) < M for all c < size(@a sub_layout) +// where sub_layout = get(layout). 
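+// Illustrative examples, worked from the definitions above:
+//   cosize(_4:_2)           -> _7   // image is {0,2,4,6}, so the smallest bound is 7
+//   cosize((_2,_2):(_4,_1)) -> _6   // last coordinate (1,1) maps to 5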
+template +CUTE_HOST_DEVICE constexpr +auto +cosize(Layout const& layout) +{ + return size(coshape(layout)); +} + +template +using cosize_t = decltype(cosize(declval())); + +template +static constexpr auto cosize_v = cosize_t::value; + +// With crd2idx(coord, shape), makes sense to have crd2idx(coord, Layout) as well +template +CUTE_HOST_DEVICE constexpr +auto +crd2idx(Coord const& c, Layout const& layout) +{ + return crd2idx(c, layout.shape(), layout.stride()); +} + +// +// Slice and Dice a layout +// + +template +CUTE_HOST_DEVICE constexpr +auto +slice(Coord const& c, Layout const& layout) +{ + return make_layout(slice(c, layout.shape()), + slice(c, layout.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +slice_and_offset(Coord const& c, Layout const& layout) +{ + return cute::make_tuple(slice(c, layout), crd2idx(c, layout)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +dice(Coord const& c, Layout const& layout) +{ + return make_layout(dice(c, layout.shape()), + dice(c, layout.stride())); +} + +// Compute a pointer offset and (potentially modified) layout from a coordinate +// This exists so it can be overloaded for ComposedLayout +template +CUTE_HOST_DEVICE constexpr +auto +domain_offset(Coord const& coord, Layout const& layout) +{ + return cute::make_tuple(layout, layout(coord)); +} + +// +// Transform the modes of a layout +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +transform_layout(Tuple const& t, F&& f, seq) +{ + return make_layout(f(get(t))...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform_layout(Tuple0 const& t0, Tuple1 const& t1, F&& f, seq, seq, seq) +{ + return make_layout(f(get(t0),get(t1))..., get(t0)..., get(t1)...); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +transform_layout(Tuple const& t, F&& f) +{ + return detail::transform_layout(t, f, make_seq{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +transform_layout(Tuple0 const& t0, Tuple1 const& t1, F&& f) +{ + constexpr int R0 = decltype(rank(t0))::value; + constexpr int R1 = decltype(rank(t1))::value; + constexpr int R = (R0 < R1) ? 
R0 : R1; + return detail::transform_layout(t0, t1, f, make_seq{}, make_range{}, make_range{}); +} + +// +// Coalesce and Filter +// + +namespace detail { + +// Look at each element and the front of the stack (in order of priority) +// front(NewLayout) get(Layout) +// s0:d0 _1:d1 => continue +// _1:d0 s1:d1 => replace_front s1:d1 +// s0:s1*d1 s1:d1 => replace_front s0*s1:d1 +// s0:d0 s1:d1 => prepend s1:d1 +// +// @pre OldShape and OldStride are flat +template +CUTE_HOST_DEVICE constexpr +auto +bw_coalesce(OldShape const& old_shape, OldStride const& old_stride, + NewShape const& new_shape, NewStride const& new_stride) +{ + if constexpr (I == -1) { + // Base case, we're done + if constexpr (is_constant<1, NewShape>::value) { + return Layout<_1,_0>{}; + } else { + return Layout{new_shape,new_stride}; + } + } else if constexpr (is_constant<1, decltype(get(old_shape))>::value) { + // shape(layout) == _1, skip it and continue + return bw_coalesce(old_shape, old_stride, new_shape, new_stride); + } else if constexpr (is_constant<1, NewShape>::value) { + // Replace our shape-1 with anything (Can only happen on input new_shape/new_stride) + return bw_coalesce(old_shape, old_stride, get(old_shape), get(old_stride)); + } else if constexpr (is_static(new_shape))>::value && + is_constant(old_shape) * get(old_stride) == get<0>(new_stride))>::value) { + // Merge modes because the shapes and strides match + return bw_coalesce(old_shape, old_stride, + replace_front(new_shape, get(old_shape) * get<0>(new_shape)), + replace_front(new_stride, get(old_stride))); + } else { + // Can't replace or merge, so prepend a new mode + return bw_coalesce(old_shape, old_stride, + prepend(new_shape, get(old_shape)), + prepend(new_stride, get(old_stride))); + } + + CUTE_GCC_UNREACHABLE; +} + +// cute::coalesce promises to not change the Layout as a function from integers to codomain. +// It accomplishes this inside of the Layout's domain, but not always outside of the domain. +// Example: (_4,_1):(_1,_0) coalesces to _4:_1. +// detail::coalesce_x preserves the Layout function inside its domain and outside. 
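+// Illustrative examples, worked from bw_coalesce above and coalesce_x below:
+//   coalesce  ((_2,(_1,_6)):(_1,(_6,_2))) -> _12:_1
+//   coalesce_x((_4,_1):(_1,_0))           -> (_4,_2):(_1,_0)  // still agrees with the input for i >= 4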
+// +// @post depth(@a result) <= 1 +// @post for all i, 0 <= i, @a layout(i) == @a result(i) +template +CUTE_HOST_DEVICE constexpr +auto +coalesce_x(Layout const& layout) +{ + auto flat_shape = flatten(layout.shape()); + auto flat_stride = flatten(layout.stride()); + + constexpr int R = decltype(rank(flat_shape))::value; + if constexpr (is_constant<1, decltype(get(flat_shape))>::value) { + return detail::bw_coalesce(flat_shape, flat_stride, Int<2>{}, get(flat_stride)); + } else { + return detail::bw_coalesce(flat_shape, flat_stride, get(flat_shape), get(flat_stride)); + } +} + +// Apply coalesce_x at the terminals of trg_profile +template +CUTE_HOST_DEVICE constexpr +auto +coalesce_x(Layout const& layout, IntTuple const& trg_profile) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank); + return cute::transform_layout(layout, trg_profile, [](auto const& l, auto const& t) { return coalesce_x(l,t); }); + } else { + return coalesce_x(layout); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +// "Simplify" the layout by combining modes that are possible to combine +// Does not respect the shape of the layout, but does preserve total size +// @post size(@a result) == size(@a layout) +// @post depth(@a result) <= 1 +// @post for all i, 0 <= i < size(@a layout), @a layout(i) == @a result(i) +template +CUTE_HOST_DEVICE constexpr +auto +coalesce(Layout const& layout) +{ + auto flat_shape = flatten(layout.shape()); + auto flat_stride = flatten(layout.stride()); + + constexpr int R = decltype(rank(flat_shape))::value; + return detail::bw_coalesce(flat_shape, flat_stride, get(flat_shape), get(flat_stride)); +} + +// Apply coalesce at the terminals of trg_profile +template +CUTE_HOST_DEVICE constexpr +auto +coalesce(Layout const& layout, IntTuple const& trg_profile) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank); + return transform_layout(layout, trg_profile, [](auto const& l, auto const& t) { return coalesce(l,t); }); + } else { + return coalesce(layout); + } + + CUTE_GCC_UNREACHABLE; +} + +// Combine static and dynamic modes of a shape. 
+// @post size(@a result) == size(@a shape) +// @post depth(@a result) <= 1 +template +CUTE_HOST_DEVICE constexpr +auto +coalesce(Shape const& shape) +{ + static_assert(is_integral::value || is_tuple::value); + + return cute::fold_first(flatten(shape), [](auto const& init, auto const& a) { + if constexpr (is_static::value == is_static::value) { + return replace_back(init, back(init) * a); // Both static or both dynamic, coalesce and replace + } else { + return append(init, a); // Can't coalesce, so append + } + }); +} + +// Replace the modes in layout that have a 0-stride with a 1-size +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Layout const& layout) +{ + return make_layout(filter_zeros(layout.stride(), layout.shape()), layout.stride()); +} + +// Replace the modes in layout that correspond to a 0 at the terminals of trg_profile with a 1-size +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Layout const& layout, IntTuple const& trg_profile) +{ + return make_layout(filter_zeros(trg_profile, layout.shape()), layout.stride()); +} + +// Remove all of the 0-strides and 1-sizes +// Return 1-shape if empty +template +CUTE_HOST_DEVICE constexpr +auto +filter(Layout const& layout) +{ + return coalesce(filter_zeros(layout)); +} + +// Apply filter at the terminals of trg_profile +template +CUTE_HOST_DEVICE constexpr +auto +filter(Layout const& layout, IntTuple const& trg_profile) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank); + return transform_layout(layout, trg_profile, [](auto const& l, auto const& t) { return filter(l,t); }); + } else { + return filter(layout); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Append, Prepend, Replace +// + +template +CUTE_HOST_DEVICE constexpr +auto +append(Layout const& layout, + Layout const& x = {}) +{ + return make_layout(append(layout.shape(), x.shape()), + append(layout.stride(), x.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +append(Layout const& layout, + Layout const& x = {}) +{ + return make_layout(append(layout.shape(), x.shape()), + append(layout.stride(), x.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +prepend(Layout const& layout, + Layout const& x = {}) +{ + return make_layout(prepend(layout.shape(), x.shape()), + prepend(layout.stride(), x.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +prepend(Layout const& layout, + Layout const& x = {}) +{ + return make_layout(prepend(layout.shape(), x.shape()), + prepend(layout.stride(), x.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +replace(Layout const& layout, + Layout const& x) +{ + return make_layout(replace(layout.shape(), x.shape()), + replace(layout.stride(), x.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +group(Layout const& layout) +{ + return make_layout(group(layout.shape()), + group(layout.stride())); +} + +// +// Composition of two layouts: lhs o rhs +// @post compatible(rhs, result) +// @post result(c) = lhs(rhs(c)) +// for all c in the domain of rhs +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +composition_impl(LShape const& lhs_shape, LStride const& lhs_stride, + RShape const& rhs_shape, RStride const& rhs_stride) +{ + if constexpr (is_tuple::value) { + // Apply the right-distributivity of Layout composition + return transform_layout(rhs_shape, rhs_stride, [&](auto const& s, auto const& d) { + return composition_impl(lhs_shape, lhs_stride, s, d); + }); + } else + if constexpr (is_scaled_basis::value) { + // Special case for a 
ScaledBasis stride + return composition_impl(basis_get(rhs_stride, lhs_shape), basis_get(rhs_stride, lhs_stride), + rhs_shape, basis_value(rhs_stride)); + } else + if constexpr (is_constant<0, RStride>::value) { + // Special case shortcut for any static stride-0 + return Layout{rhs_shape, rhs_stride}; + } else + if constexpr (is_integral::value) { + // Special case shortcut for any integral LShape + return Layout{rhs_shape, rhs_stride * lhs_stride}; + } else + if constexpr (is_constant<1, RStride>::value) { + // Special case shortcut for any static stride-1 + constexpr int R = rank_v; + auto result_shape_0 = take<0,R-1>(lhs_shape); + + // Mod out the rhs_shape from the lhs_shape + auto const [result_shape_1, rest_shape] = fold(result_shape_0, cute::make_tuple(cute::make_tuple(), rhs_shape), + [] (auto const& init, auto const& si) { + return cute::make_tuple(append(get<0>(init), shape_min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si))); + }); + + // Jump into coalesce and append (rest_shape, get(lhs_stride)) + return detail::bw_coalesce(result_shape_1, lhs_stride, rest_shape, get(lhs_stride)); + } else { + // General case: integral RShape and RStride, tuple LShape and LStride + constexpr int R = rank_v; + auto result_shape_0 = take<0,R-1>(lhs_shape); + auto result_stride_0 = take<0,R-1>(lhs_stride); + + // Divide out the rhs_stride from the lhs_shape + auto const [result_shape_1, rest_stride] = fold(result_shape_0, cute::make_tuple(cute::make_tuple(), rhs_stride), + [] (auto const& init, auto const& di) { + return cute::make_tuple(append(get<0>(init), shape_div(di, get<1>(init))), shape_div(get<1>(init), di)); + }); + + // Apply any lhs_shape changes to the stride + auto result_stride_1 = elem_scale(result_stride_0, shape_div(result_shape_0, result_shape_1)); + + // Mod out the rhs_shape from the lhs_shape + auto const [result_shape_2, rest_shape] = fold(result_shape_1, cute::make_tuple(cute::make_tuple(), rhs_shape), + [] (auto const& init, auto const& si) { + return cute::make_tuple(append(get<0>(init), shape_min(abs(si), get<1>(init))), shape_div(get<1>(init), abs(si))); + }); + + // Jump into coalesce and append (rest_shape, rest_stride * get(lhs_stride)) + return detail::bw_coalesce(result_shape_2, result_stride_1, rest_shape, rest_stride * get(lhs_stride)); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +composition(Layout const& lhs, + Layout const& rhs) +{ + auto coprofile = repeat_like(decltype(coshape(rhs)){}, Int<0>{}); + auto flat_lhs = detail::coalesce_x(lhs, coprofile); + return detail::composition_impl(flat_lhs.shape(), flat_lhs.stride(), rhs.shape(), rhs.stride()); +} + +template +CUTE_HOST_DEVICE constexpr +auto +composition(Layout const& lhs, + Tiler const& rhs) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank); + // Drop any modes of lhs that aren't hit by rhs + return detail::transform_layout(lhs, rhs, [](auto const& l, auto const& r) { return composition(l,r); }, make_seq::value>{}, seq<>{}, seq<>{}); + } else if constexpr (is_underscore::value) { + return lhs; + } else if constexpr (is_integral::value) { + auto flat_lhs = detail::coalesce_x(lhs); + return detail::composition_impl(flat_lhs.shape(), flat_lhs.stride(), rhs, Int<1>{}); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Complement +// +// Build the complement of a layout. 
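+// Illustrative examples, worked from the definition below:
+//   complement(_4:_1, Int<24>{})           -> _6:_4
+//   complement((_2,_2):(_1,_6), Int<24>{}) -> (_3,_2):(_2,_12)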
+// @post size(@a result) >= @a cosize_hi / size(filter(@a layout))); +// @post For all i in [1,size(@a result)), +// @a result(i) < @a result(i-1) +// For all j in [0, size(@a layout)), +// @a result(i) != @a layout(j) +// + +namespace detail { + +// @pre @a layout has been filtered (flattened and no stride-0 or size-1 modes). +template +CUTE_HOST_DEVICE constexpr +auto +complement(Shape const& shape, Stride const& stride, CoTarget const& cotarget) +{ + if constexpr (is_constant<0, Stride>::value) { + // Special case for irreducible rank-1 stride-0 layout + return make_layout(coalesce(cotarget)); + } else { + // General case + constexpr int R = rank_v; + static_assert(R == 1 || is_static::value, + "Dynamic-stride complement only for rank-1 layouts"); + + // Should just be a sort and a fold... + // Then we could even handle dynamic strides (but they would destroy all static strides) + auto [shape_, stride_, result_shape_, result_stride] = + fold(make_seq{}, + cute::make_tuple(shape, stride, cute::make_tuple(), cute::make_tuple(Int<1>{})), + [](auto const& init, auto i) + { + auto [shape, stride, result_shape, result_stride] = init; + auto min_stride = cute::min(stride); + auto min_idx = cute::find(stride, min_stride); + auto new_shape = min_stride / get(result_stride); + auto new_stride = min_stride * get(shape); + static_assert(not is_constant<0, decltype(new_shape)>::value, "Non-injective Layout detected in complement."); + + return cute::make_tuple(remove(shape), // Remove the min_idx from shape + remove(stride), // Remove the min_idx from stride + append(result_shape , new_shape ), // new shape = min_stride / last_stride + append(result_stride, new_stride)); // new stride = min_stride * curr_shape + }); + + // Append the last shape mode + auto new_shape = get<0>(stride_) / get(result_stride); // new shape = min_stride / last_stride + static_assert(not is_constant<0, decltype(new_shape)>::value, "Non-injective Layout detected in complement."); + auto result_shape = append(result_shape_, new_shape); + + // Compute the rest_shape and rest_stride + auto new_stride = get<0>(stride_) * get<0>(shape_); // new stride = min_stride * curr_shape + auto rest_shape = coalesce(ceil_div(cotarget, new_stride)); + auto rest_stride = compact_major(rest_shape, new_stride); + + // Coalesce and append (rest_shape, rest_stride) + return coalesce(make_layout(make_shape (result_shape , rest_shape ), + make_stride(result_stride, rest_stride))); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +complement(Layout const& layout, CoTarget const& cotarget) +{ + auto filter_layout = filter(layout); + return detail::complement(filter_layout.shape(), filter_layout.stride(), shape(cotarget)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +complement(Layout const& layout) +{ + auto filter_layout = filter(layout); + return detail::complement(filter_layout.shape(), filter_layout.stride(), cosize(filter_layout)); +} + +// +// Right-Inverse and Left-Inverse +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +inverse_seq(Shape const& shape, Stride const& stride, seq) +{ + auto next_I = cute::find_if(stride, [](auto a) { return is_constant{}; }); + + if constexpr (next_I == decltype(rank(stride))::value) { + // If not found, return current seq + return seq{}; + } else { + // auto next_stride = get(shape) * get(stride); + // NOTE: Needed for g++-7 + using next_stride = decltype(get(shape) * get(stride)); + + if constexpr (is_static::value && 
!is_constant::value) { + // If next_stride is static and unique, then continue + return inverse_seq(shape, stride, seq{}); + } else { + // Else return current seq + next_I + return seq{}; + } + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +// +// Build the right-inverse of a layout +// @pre is_static +// @result A layout @a result such that +// @a layout(@a result(i)) == i for all i < size(@a result) +// @result A layout @a result such that +// composition(@a layout, @a result) is identical to make_layout(shape(result)) +// + +template +CUTE_HOST_DEVICE constexpr +auto +right_inverse(Layout const& layout) +{ + auto flat_layout = coalesce(layout); + auto astride = transform_leaf(flat_layout.stride(), abs_fn{}); + + // Find Int<1>{}, the starting stride, and follow the strides to gen inverse_seq + [[maybe_unused]] auto iseq = detail::inverse_seq<1>(flat_layout.shape(), astride, seq<>{}); + + if constexpr (iseq.size() == 0) { + return Layout<_1,_0>{}; // Empty case, nothing found + } else { + // Generate the corresponding new strides and construct + auto rstride = compact_major(flat_layout.shape()); + return make_layout(unwrap(transform(iseq, [&](auto i) { return shape(flat_layout); })), + unwrap(transform(iseq, [&](auto i) { return signum(stride(flat_layout)) * get(rstride); }))); + } + + CUTE_GCC_UNREACHABLE; +} + +CUTE_HOST_DEVICE constexpr +auto +right_inverse(Underscore const& _) +{ + return _; +} + +// +// Build the left-inverse of a layout +// @pre is_static +// @pre @a layout is an injective function +// @result A layout @a result such that +// @a result(@a layout(i)) == i for all i < size(@a layout) +// @result A layout @a result such that +// composition(@a result, @a layout) is identical to make_layout(shape(layout)) +// + +template +CUTE_HOST_DEVICE constexpr +auto +left_inverse(Layout const& layout) +{ + return right_inverse(make_layout(layout, complement(layout))); +} + +CUTE_HOST_DEVICE constexpr +auto +left_inverse(Underscore const& _) +{ + return _; +} + +// +// Max Common Layout +// + +/* Return a layout that points to the maximum number of contiguous elements + * that logically correspond in the layouts of @a a and @a b. + * + * @returns Layout R + * @post For all 0 <= i < size(R), a(R(i)) == i and b(R(i)) == i + */ +template +CUTE_HOST_DEVICE constexpr +auto +max_common_layout(Layout const& a, + Layout const& b) +{ + Layout inv_b = right_inverse(b); + Layout common = coalesce(composition(a, inv_b)); + + // Keep only the static identity component of the common layout + if constexpr (is_static(common))>::value && + is_constant<1, decltype(stride<0>(common))>::value) { + // Truncate to the size of the contiguous vector (static stride-1 mode) + return composition(inv_b, layout<0>(common)); + } else { + return Layout<_1,_0>{}; + } +} + +/* Return Int such that N is the maximum number of contiguous elements + * that logically correspond in the layouts of @a a and @a b. 
+ * + * @returns Int with N >= 1 + * @post For all 0 <= n < N, a(b.get_1d_coord(n)) == n + * (NOTE: Problems with negative strides/coords in this post-condition) + */ +template +CUTE_HOST_DEVICE constexpr +auto +max_common_vector(Layout const& a, + Layout const& b) +{ + Layout common = coalesce(composition(a, right_inverse(b))); + + // Keep only the static identity component of the common layout + if constexpr (is_static(common))>::value && + is_constant<1, decltype(stride<0>(common))>::value) { + // Truncate to the size of the contiguous vector (static stride-1 mode) + return shape<0>(common); + } else { + return Int<1>{}; + } + + CUTE_GCC_UNREACHABLE; +} + +/* Return a layout that distributes ShapeB over ShapeA. + * + * @returns Layout result + * @post evenly_divides(@a b, size(@a result)) + * @post evenly_divides(@a a, @a result) + * @post For all i,j in [0,size(@a result)) with i < j, @a result(i) < @a result(j). Surjective and Ordered. + * @post composition(make_layout(shape(@a a)), @a result) is admissible + * \code + * // Note that 6 does not divide this shape + * Layout layoutA = Layout,Int<14>>>{}; + * + * // Want to tile any 6 elements and don't care where they come from + * Layout dist = domain_distribute(layoutA, Int<6>{}); // (_3,_2):(_1,_15) + * + * // Not guaranteed to find all 6 though... + * CUTE_STATIC_ASSERT_V(Int<6>{} == size(dist)); + * + * Layout result = zipped_divide(layoutA, dist); // (_6,Rest) + * \endcode + */ +template +CUTE_HOST_DEVICE constexpr +auto +domain_distribute(ShapeA const& a, ShapeB const& b) +{ + static_assert(is_integral::value); + static_assert(is_static::value); + + auto flat_shape_a = flatten(shape(a)); + + static_assert(is_static::value); + + // Compute the shape of the result + auto [result_shape, b_rest] = cute::fold(flat_shape_a, cute::make_tuple(cute::tuple<>{}, size(b)), [](auto init, auto a_) { + auto [result, b_] = init; + auto gcd_ = gcd(a_, b_); + return cute::make_tuple(append(result, gcd_), b_ / gcd_); + }); + + // Compute the stride of the result + auto result_stride = compact_major(flat_shape_a); + + return coalesce(make_layout(result_shape, result_stride)); +} + +// +// Kernel (Nullspace) of a Layout +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +nullspace_seq(Stride const& stride, seq) +{ + if constexpr (NextI == rank_v) { + return seq{}; + } else + if constexpr (is_constant<0, decltype(get(stride))>::value) { + return detail::nullspace_seq(stride, seq{}); + } else { + return detail::nullspace_seq(stride, seq{}); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +// +// Build the nullspace of a layout +// @result A layout @a result such that +// size(@a result) == size(@a layout) / size(filter(@a layout)) +// @a layout(@a result(i)) == 0 for all i < size(@a result) +// + +template +CUTE_HOST_DEVICE constexpr +auto +nullspace(Layout const& layout) +{ + auto flat_layout = flatten(layout); + + auto iseq = detail::nullspace_seq<0>(flat_layout.stride(), seq<>{}); + + if constexpr (iseq.size() == 0) { + return Layout<_1,_0>{}; // Empty case, nothing found + } else { + // Generate the corresponding new strides and construct + auto rstride = compact_major(flat_layout.shape()); + return make_layout(unwrap(transform(iseq, [&](auto i) { return shape(flat_layout); })), + unwrap(transform(iseq, [&](auto i) { return get(rstride); }))); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Zip +// + +template +CUTE_HOST_DEVICE constexpr +auto +zip(Layout const& layout) +{ + return make_layout(zip(layout.shape()), + 
zip(layout.stride())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +zip(Layout const& layoutA, + Layout const& layoutB) +{ + return make_layout(zip(layoutA.shape(), layoutB.shape()), + zip(layoutA.stride(), layoutB.stride())); +} + +// +// Tile unzip +// Logical product and logical divide (on layouts) produce rank-2 results by design. +// Follow the profile of @a tile and zip the rank-2 modes located at the terminals into +// their own mode. +// + +template +CUTE_HOST_DEVICE constexpr +auto +tile_unzip(Layout const& layout, + Tiler const& tiler) +{ + return make_layout(zip2_by(layout.shape(), tiler), + zip2_by(layout.stride(), tiler)); +} + +// +// Logical divide +// + +template +CUTE_HOST_DEVICE constexpr +auto +logical_divide(Layout const& layout, + Layout const& tiler) +{ + return composition(layout, make_layout(tiler, complement(tiler, shape(layout)))); +} + +template +CUTE_HOST_DEVICE constexpr +auto +logical_divide(Layout const& layout, + Tiler const& tiler) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank, "logical_divide: Too many modes in tiler."); + return transform_layout(layout, tiler, [](auto const& l, auto const& t) { return logical_divide(l,t); }); + } else if constexpr (is_underscore::value) { + return layout; + } else if constexpr (is_integral::value) { + return logical_divide(layout, make_layout(tiler)); + } + + CUTE_GCC_UNREACHABLE; +} + +// Generalization of ceil_div for Layout lhs +// is effectively the "rest mode" of logical_divide. +// Occurs in the calculation of gridDim, for example, for generalized tilers +// Example: +// dim3 gridDim(size(ceil_div(problem_shape_M, cta_tiler_M)), +// size(ceil_div(problem_shape_N, cta_tiler_N))); +// This does not consider compositional acceptance, so it may be the case that +// ceil_div produces a result while logical_divide (and friends) do not. +template +CUTE_HOST_DEVICE constexpr +auto +ceil_div(Target const& target, + Layout const& tiler) +{ + return shape(complement(tiler, shape(target))); +} + +// +// Convenience operator +// that produces layouts like ((BLK_A,BLK_B,...),(a,b,...,x,y)) +// by gathering the tile modes and residuals into a rank-2 result. 
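+// Illustrative example, worked from logical_divide above (an 8x6 column-major matrix
+// tiled by 4x3; values chosen for illustration only):
+//   zipped_divide((_8,_6):(_1,_8), Shape<_4,_3>{}) -> ((_4,_3),(_2,_2)):((_1,_8),(_4,_24))
+//   i.e. ((tile modes),(rest modes)); tile (i,j) starts at offset 4*i + 24*j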
+// + +template +CUTE_HOST_DEVICE constexpr +auto +zipped_divide(Layout const& layout, + Tiler const& tiler) +{ + return tile_unzip(logical_divide(layout, tiler), tiler); +} + +// Same as zipped_divide, but unpacks the second mode: ((BLK_A,BLK_B,...),a,b,...,x,y) +template +CUTE_HOST_DEVICE constexpr +auto +tiled_divide(Layout const& layout, + Tiler const& tiler) +{ + auto result = zipped_divide(layout, tiler); + + auto R1 = rank<1>(result); + return result(_, repeat(_)); +} + +// Same as zipped_divide, but unpacks both modes: (BLK_A,BLK_B,...,a,b,...,x,y) +template +CUTE_HOST_DEVICE constexpr +auto +flat_divide(Layout const& layout, + Tiler const& tiler) +{ + auto result = zipped_divide(layout, tiler); + + auto R0 = rank<0>(result); + auto R1 = rank<1>(result); + return result(repeat(_), repeat(_)); +} + +// +// Logical product +// + +template +CUTE_HOST_DEVICE constexpr +auto +logical_product(Layout const& block, + Layout const& tiler) +{ + return make_layout(block, composition(complement(block, size(block)*cosize(tiler)), tiler)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +logical_product(Layout const& block, + Tiler const& tiler) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value <= Layout::rank, "logical_product: Too many modes in tiler."); + return transform_layout(block, tiler, [](auto const& l, auto const& t) { return logical_product(l,t); }); + } else if constexpr (is_underscore::value) { + return block; + } else if constexpr (is_integral::value) { + return logical_product(block, make_layout(tiler)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Convenience operator +// that produces layouts like ((BLK_A,BLK_B,...),(a,b,...,x,y)) +// by gathering the block modes and products into a rank-2 result. +// + +template +CUTE_HOST_DEVICE constexpr +auto +zipped_product(Layout const& block, + Tiler const& tiler) +{ + return tile_unzip(logical_product(block, tiler), tiler); +} + +// Same as zipped_product, but unpacks the second mode: ((BLK_A,BLK_B,...),a,b,...,x,y) +template +CUTE_HOST_DEVICE constexpr +auto +tiled_product(Layout const& block, + Tiler const& tiler) +{ + auto result = zipped_product(block, tiler); + + auto R1 = rank<1>(result); + return result(_, repeat(_)); +} + +// Same as zipped_product, but unpacks both modes: (BLK_A,BLK_B,...,a,b,...,x,y) +template +CUTE_HOST_DEVICE constexpr +auto +flat_product(Layout const& block, + Tiler const& tiler) +{ + auto result = zipped_product(block, tiler); + + auto R0 = rank<0>(result); + auto R1 = rank<1>(result); + return result(repeat(_), repeat(_)); +} + +// +// Rank-sensitive products +// + +// blocked_product -- Reproduce a block over a tiler. +// Think of every element of "tiler" as a "block" +// and return the layout of the resulting structure. +// @post rank(@a result) == cute::max(rank(@a block), rank(@a tiler)) +template +CUTE_HOST_DEVICE constexpr +auto +blocked_product(Layout const& block, + Layout const& tiler) +{ + constexpr int R = cute::max(rank_v, rank_v); + + auto result = logical_product(append(block), append(tiler)); + + return coalesce(zip(get<0>(result), get<1>(result)), tuple_repeat(Int<1>{})); +} + +// raked_product -- Reproduce a block over a tiler with block-interleaving. +// Think of every element of "tiler" as a "block", interleave those blocks, +// and return the layout of the resulting structure. 
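+// Illustrative comparison, worked by hand from logical_product above (values assumed):
+//   blocked_product((_2,_2):(_1,_2), (_3,_4):(_1,_3)) -> ((_2,_3),(_2,_4)):((_1,_4),(_2,_12))
+//   raked_product  ((_2,_2):(_1,_2), (_3,_4):(_1,_3)) -> ((_3,_2),(_4,_2)):((_4,_1),(_12,_2))
+// Both cover the same 6x8 codomain; blocked keeps each 2x2 block contiguous, raked interleaves them.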
+// @post rank(@a result) == cute::max(rank(@a block), rank(@a tiler)) +template +CUTE_HOST_DEVICE constexpr +auto +raked_product(Layout const& block, + Layout const& tiler) +{ + constexpr int R = cute::max(rank_v, rank_v); + + auto result = logical_product(append(block), append(tiler)); + + return coalesce(zip(get<1>(result), get<0>(result)), tuple_repeat(Int<1>{})); +} + +// tile_to_shape -- Perform a product of a layout so that the result matches a target shape. +// This is similar to blocked_product, but specifies the result shape instead of the +// product shape, which is more convenient in certain circumstances. +// @param block The layout to repeat +// @param trg_shape The target shape of the result +// @param ord_shape The order of the modes of @a trg_shape to tile @a layout with. +// Defaults to GenColMajor, so @a layout will repeat +// across the first mode first, the second mode second, etc +// E.g. Step<_2,_1,_3> will cause @a layout to repeat +// across the second mode first, the first mode second, and the third mode last. +// @pre rank(@a block) <= rank(@a trg_shape) +// @post compatible(@a trg_shape, shape(@a result)) +template +CUTE_HOST_DEVICE constexpr +auto +tile_to_shape(Layout const& block, + TrgShape const& trg_shape, + ModeOrder const& ord_shape = {}) +{ + CUTE_STATIC_ASSERT_V(rank(block) <= rank(trg_shape), "Rank of layout must be <= rank of target shape."); + constexpr int R = rank_v; + + auto padded_block = append(block); + + auto block_shape = product_each(shape(padded_block)); + auto target_shape = product_each(shape(trg_shape)); + + // Assert proper division + if constexpr (is_static::value) { + CUTE_STATIC_ASSERT_V(evenly_divides(target_shape, block_shape), + "tile_to_shape: block shape does not divide the target shape."); + } + + auto product_shape = ceil_div(target_shape, block_shape); + + return coalesce(blocked_product(padded_block, make_ordered_layout(product_shape, ord_shape)), product_shape); +} + +// +// Upcast +// For stride-1 mode, divide size by N. Divide all other strides by N. +// + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(Shape const& shape, Stride const& stride) +{ + if constexpr (is_tuple::value) { // tuple stride + return transform_layout(shape, stride, [](auto const& s, auto const& d) { return upcast(s,d); }); + } else if constexpr (is_constant<0, Stride>::value) { // static-0 stride + return Layout{shape,stride}; + } else if constexpr (is_static::value) { // static stride + return make_layout(shape_div(shape, shape_div(Int{}, abs(stride))), + shape_div(stride, Int{})); + } else { // dynamic stride + // assume dynamic strides are larger than N and divisible + // assert(stride % N == 0); + return make_layout(shape, safe_div(stride, Int{})); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(Layout const& layout) +{ + return upcast(layout.shape(), layout.stride()); +} + +// +// Downcast +// For stride-1 mode, multiply size by N. Multiply all other strides by N. 
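+// Illustrative round trip, worked from the definitions (e.g. viewing 16-bit data as 32-bit and back):
+//   upcast<2>((_8,_4):(_1,_8))   -> (_4,_4):(_1,_4)
+//   downcast<2>((_4,_4):(_1,_4)) -> (_8,_4):(_1,_8)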
+// + +template +CUTE_HOST_DEVICE constexpr +auto +downcast(Shape const& shape, Stride const& stride) +{ + if constexpr (is_tuple::value) { + return transform_layout(shape, stride, [](auto const& s, auto const& d) { return downcast(s,d); }); + } else if constexpr (is_constant<1, Stride>::value || is_constant<-1, Stride>::value) { + return make_layout(shape * Int{}, stride); + } else { + return make_layout(shape, stride * Int{}); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +downcast(Layout const& layout) +{ + CUTE_STATIC_ASSERT(has_int1::value, "Downcast requires adjacent elements"); + return downcast(layout.shape(), layout.stride()); +} + +// +// Recast +// + +template +CUTE_HOST_DEVICE constexpr +auto +recast_layout(Layout const& layout) +{ + using scale = decltype(trait_ratio(sizeof_bits{}, sizeof_bits{})); + if constexpr (scale::num == 1 && scale::den == 1) { + return layout; + } + else if constexpr (scale::num == 1) { + return downcast(layout); + } + else if constexpr (scale::den == 1) { + return upcast(layout); + } + else { + return downcast(upcast(layout)); + } + + CUTE_GCC_UNREACHABLE; +} + +// Determine the maximum alignment of a Layout. +// The maximum alignment is the largest N for which upcast(layout) will compile. +// upcast(layout) compiles when the static shapes and strides pass divisibility checks. +// Therefore, upcast(layout) will also compile for all divisors M of N. +// Note that this only considers the static shapes and strides of the Layout +// in symmetry with upcast only checking against static shapes and strides and assuming all +// dynamic shapes and strides are large and multiples of N. +template +CUTE_HOST_DEVICE constexpr +auto +max_alignment(Layout const& layout) +{ + auto flat_layout = coalesce(layout); + auto static_shape = transform( shape(flat_layout), [](auto s){ return conditional_return::value>(s, Int<1>{}); }); + auto static_stride = transform(stride(flat_layout), [](auto d){ return conditional_return::value>(d, Int<0>{}); }); + auto filter_layout = make_layout(static_shape, static_stride); + auto permuted = logical_divide(filter_layout, right_inverse(filter_layout)); + return gcd(size<0>(permuted), stride<1>(permuted)); +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(Layout const& layout) +{ + print(layout.shape()); print(":"); print(layout.stride()); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, Layout const& layout) +{ + return os << shape(layout) << ":" << stride(layout); +} +#endif + +// Generic 2D Layout to console table +template +CUTE_HOST_DEVICE +void +print_layout(Layout const& layout) // (m,n) -> idx +{ + CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); + + int idx_width = num_digits(cosize(layout)) + 2; + const char* delim = "+-----------------------"; + + print(layout); print("\n"); + + // Column indices + print(" "); + for (int n = 0; n < size<1>(layout); ++n) { printf(" %*d ", idx_width-2, n); } + printf("\n"); + + // Print out A m-by-n + for (int m = 0; m < size<0>(layout); ++m) { + // Header + print(" "); + for (int n = 0; n < size<1>(layout); ++n) { printf("%.*s", idx_width+1, delim); } + printf("+\n"); + // Values + printf("%2d ", m); // Row indices + for (int n = 0; n < size<1>(layout); ++n) { printf("| %*d ", idx_width-2, int(layout(m,n))); } + printf("|\n"); + } + // Footer + print(" "); + for (int n = 0; n < size<1>(layout); ++n) { printf("%.*s", idx_width+1, delim); } + printf("+\n"); +} + +// Generic ThrVal 2D 
Layout to console table +template +CUTE_HOST_DEVICE +void +print_layout(Layout const& layout, ThrID const& thrid) // (m,n) -> (tid,vid) and tid -> thr_idx +{ + CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); + + print(layout); print("\n"); + print(thrid); print("\n"); + + // Print out m-by-n + for (int m = 0; m < size<0>(layout); ++m) { + // Header + for (int n = 0; n < size<1>(layout); ++n) printf("+------"); + printf("+\n"); + // Values + for (int n = 0; n < size<1>(layout); ++n) printf("|%03d-%02d", int(thrid(layout(m,n) % size(thrid))), int(layout(m,n) / size(thrid))); + printf("|\n"); + } + // Footer + for (int n = 0; n < size<1>(layout); ++n) printf("+------"); + printf("+\n"); +} + +struct TikzColor_White { + CUTE_HOST_DEVICE char const* + operator()(int idx) const { + return "white"; + } +}; + +struct TikzColor_BWx8 { + CUTE_HOST_DEVICE char const* + operator()(int idx) const { + static char const* color_map[8] = {"black!00", "black!40", "black!20", "black!60", + "black!10", "black!50", "black!30", "black!70"}; + return color_map[idx % 8]; + } +}; + +struct TikzColor_TV { + CUTE_HOST_DEVICE char const* + operator()(int tid, int vid) const { + static char const* color_map[8] = {"{rgb,255:red,175;green,175;blue,255}", + "{rgb,255:red,175;green,255;blue,175}", + "{rgb,255:red,255;green,255;blue,175}", + "{rgb,255:red,255;green,175;blue,175}", + "{rgb,255:red,210;green,210;blue,255}", + "{rgb,255:red,210;green,255;blue,210}", + "{rgb,255:red,255;green,255;blue,210}", + "{rgb,255:red,255;green,210;blue,210}"}; + return color_map[tid % 8]; + } +}; + +// Generic 2D Layout to LaTeX printer +template +CUTE_HOST_DEVICE +void +print_latex(LayoutA const& layout_a, // (m,n) -> idx + TikzColorFn color = {}) // lambda(idx) -> tikz color string +{ + CUTE_STATIC_ASSERT_V(rank(layout_a) <= Int<2>{}); + auto layout = append<2>(layout_a, Layout<_1,_0>{}); + + // Commented print(layout) + printf("%% Layout: "); print(layout); printf("\n"); + // Header + printf("\\documentclass[convert]{standalone}\n" + "\\usepackage{tikz}\n\n" + "\\begin{document}\n" + "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); + + // Layout + for (int i = 0; i < size<0>(layout); ++i) { + for (int j = 0; j < size<1>(layout); ++j) { + int idx = layout(i,j); + printf("\\node[fill=%s] at (%d,%d) {%d};\n", + color(idx), i, j, idx); + } + } + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n", + int(size<0>(layout)), int(size<1>(layout))); + // Labels + for (int i = 0, j = -1; i < size<0>(layout); ++i) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i); + } + for (int i = -1, j = 0; j < size<1>(layout); ++j) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j); + } + + // Footer + printf("\\end{tikzpicture}\n" + "\\end{document}\n"); +} + +// Generic ThrVal 2D Layout to LaTeX TikZ +template +CUTE_HOST_DEVICE +void +print_latex(Layout const& layout, // (m,n) -> (tid,vid) + ThrID const& thr, // tid -> thr_idx + TikzColorFn color = {}) // lambda(thr_idx,val_idx) -> tikz color string +{ + CUTE_STATIC_ASSERT_V(rank(layout) == Int<2>{}); + + // Commented prints + printf("%% Layout: "); print(layout); printf("\n"); + printf("%% ThrID : "); print(thr); printf("\n"); + // Header + printf("\\documentclass[convert]{standalone}\n" + "\\usepackage{tikz}\n\n" + "\\begin{document}\n" + "\\begin{tikzpicture}[x={(0cm,-1cm)},y={(1cm,0cm)},every node/.style={minimum size=1cm, outer sep=0pt}]\n\n"); + + // Layout + for (int 
i = 0; i < size<0>(layout); ++i) { + for (int j = 0; j < size<1>(layout); ++j) { + int thrid = layout(i,j) % size(thr); + int val_idx = layout(i,j) / size(thr); + int thr_idx = thr(thrid); + + printf("\\node[fill=%s] at (%d,%d) {\\shortstack{T%d \\\\ V%d}};\n", + color(thr_idx, val_idx), + i, j, + thr_idx, val_idx); + } + } + // Grid + printf("\\draw[color=black,thick,shift={(-0.5,-0.5)}] (0,0) grid (%d,%d);\n\n", + int(size<0>(layout)), int(size<1>(layout))); + // Labels + for (int i = 0, j = -1; i < size<0>(layout); ++i) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, i); + } + for (int j = 0, i = -1; j < size<1>(layout); ++j) { + printf("\\node at (%d,%d) {\\Large{\\texttt{%d}}};\n", i, j, j); + } + + // Footer + printf("\\end{tikzpicture}\n" + "\\end{document}\n"); +} + +} // end namespace cute diff --git a/include/cute/layout_composed.hpp b/include/cute/layout_composed.hpp new file mode 100644 index 0000000000..26ae8dc76c --- /dev/null +++ b/include/cute/layout_composed.hpp @@ -0,0 +1,661 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE, CUTE_GCC_UNREACHABLE +#include // cute::tuple +#include // cute::true_type, cute::false_type, cute::Int + +/* This implements a ComposedLayout of the form + * LayoutA o Offset o LayoutB + * and is useful in cases where composition() does not or cannot apply to LayoutA and LayoutB. + * For example, when the "divisibility condition" in shape_div is violated in composition(LayoutA, LayoutB). + * + * This ComposedLayout provides similar functionality to Layout including tiling, partitioning, + * coordinate-to-index mapping and layout manipulations, but is not considered a "normal" layout. 
+ * For example, this layout provides shape() and size() functions, but does not provide stride() functions. + * Mostly, the similar functionality is accomplished by applying each operation to LayoutB only + * as LayoutB defines the domain. + */ + +namespace cute +{ + +// A Layout of non-trivially composable functions: F o I o L +template +struct ComposedLayout : private cute::tuple // EBO for static layouts +{ + CUTE_HOST_DEVICE constexpr + ComposedLayout(LayoutA const& layoutA = {}, + Offset const& offset = {}, + LayoutB const& layoutB = {}) + : cute::tuple(layoutA, offset, layoutB) + {} + + // + // Accessors + // + + static constexpr int rank = LayoutB::rank; + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout_a() const { + return get<0>(static_cast const&>(*this)); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + offset() const { + return get<1>(static_cast const&>(*this)); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout_b() const { + return get<2>(static_cast const&>(*this)); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout() const { + return *this; + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + shape() const { + return layout_b().shape(); + } + + // Doesn't really make sense to ask for the strides of this "layout" + CUTE_HOST_DEVICE constexpr + decltype(auto) + stride() const = delete; + + // + // Mappings + // + + // Map a logical coordinate to a linear index (Coord has no Underscore slice operators) + // OR + // Slice the layout and return the sublayout (Coord has an Underscore slice op) + template + CUTE_HOST_DEVICE constexpr + auto + operator()(Coord const& coord) const { + if constexpr (has_underscore::value) { + return slice(coord, *this); + } else { + return layout_a()(offset() + layout_b()(coord)); // (A o O o B)(c) + } + + CUTE_GCC_UNREACHABLE; + } + + // Convenience function for multi-dimensional coordinates + template + CUTE_HOST_DEVICE constexpr + auto + operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const { + return operator()(make_coord(c0,c1,cs...)); + } + + // + // Compose + // + + template + CUTE_HOST_DEVICE constexpr + auto + compose(OtherLayout const& other) const { + return composition(*this, other); + } + + template + CUTE_HOST_DEVICE constexpr + auto + compose(Layouts const&... layouts) const { + return composition(*this, make_tile(layouts...)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + with_shape(OtherShape const& shape) const { + return composition(*this, make_layout(shape)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + with_shape(Shapes const&... shapes) const { + return composition(*this, make_layout(make_shape(shapes...))); + } + + // + // Tile + // + + template + CUTE_HOST_DEVICE constexpr + auto + tile(OtherLayout const& other) const { + return tiled_divide(*this, other); + } + + template + CUTE_HOST_DEVICE constexpr + auto + tile(Layouts const&... 
layouts) const { + return tiled_divide(*this, make_tile(layouts...)); + } + + // Equality, return a static or dynamic boolean + template + CUTE_HOST_DEVICE constexpr + auto + operator==(ComposedLayout const& other) const { + return this->layout_a() == other.layout_a() && + this->layout_b() == other.layout_b() && + this->offset() == other.offset(); + } +}; + +template +struct is_layout> : true_type {}; + +template +struct is_composed_layout : false_type {}; +template +struct is_composed_layout> : true_type {}; + +// +// Constructors +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_composed_layout(LayoutA const& layoutA, + Offset const& offset, + LayoutB const& layoutB) +{ + return ComposedLayout{layoutA, offset, layoutB}; +} + +// +// Utilities +// + +// Return the layout of a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +layout(ComposedLayout const& clayout) +{ + return composition(clayout.layout_a(), clayout.offset(), layout(clayout.layout_b())); +} + +// Return the shape of a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +shape(ComposedLayout const& layout) +{ + return shape(layout.layout_b()); +} + +// Doesn't make sense to directly ask for the strides of this "layout" +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +stride(ComposedLayout const& layout) = delete; + +// Return the number of elements in a mode +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +size(ComposedLayout const& layout) +{ + return size(layout.layout_b()); +} + +// Return the number of modes +template +CUTE_HOST_DEVICE constexpr +auto +rank(ComposedLayout const& layout) +{ + return rank(layout.layout_b()); +} + +// Return the depth of the layout +template +CUTE_HOST_DEVICE constexpr +auto +depth(ComposedLayout const& layout) +{ + return depth(layout.layout_b()); +} + +// Return the codomain size of a mode +template +CUTE_HOST_DEVICE constexpr +auto +cosize(ComposedLayout const& layout) +{ + return cosize(layout.layout_b()); +} + +// +// Operations to manipulate Layouts like a tuple of pairs +// + +template +CUTE_HOST_DEVICE constexpr +auto +get(ComposedLayout const& a) +{ + return composition(a.layout_a(), a.offset(), get(a.layout_b())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +take(ComposedLayout const& a) +{ + return composition(a.layout_a(), a.offset(), take(a.layout_b())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +flatten(ComposedLayout const& a) +{ + return composition(a.layout_a(), a.offset(), flatten(a.layout_b())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +append(ComposedLayout const& a, X const& x) +{ + return composition(a.layout_a(), a.offset(), append(a.layout_b(), x)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +group(ComposedLayout const& a) +{ + return composition(a.layout_a(), a.offset(), group(a.layout_b())); +} + +// +// Slice a ComposedLayout +// + +template +CUTE_HOST_DEVICE constexpr +auto +slice_and_offset(Coord const& coord, ComposedLayout const& layout) +{ + auto [slice, offset] = slice_and_offset(coord, layout.layout_b()); + return cute::make_tuple(ComposedLayout{layout.layout_a(), layout.offset() + offset, slice}, Int<0>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +slice(Coord const& coord, ComposedLayout const& layout) +{ + return get<0>(slice_and_offset(coord, layout)); +} + +// Compute a pointer offset and (potentially modified) layout from a coordinate +// For composed layout tensors the offset is accumulated in the layout itself while pointer is not updated +template +CUTE_HOST_DEVICE constexpr +auto 
+domain_offset(Coord const& coord, ComposedLayout const& layout) +{ + return cute::make_tuple(ComposedLayout{layout.layout_a(), layout.offset() + layout.layout_b()(coord), layout.layout_b()}, Int<0>{}); +} + +// +// composition +// + +template +CUTE_HOST_DEVICE constexpr +auto +composition(LayoutA const& layoutA, + Offset const& offset, + LayoutB const& layoutB) +{ + return ComposedLayout{layoutA, offset, layoutB}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +composition(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), composition(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +composition(Layout const& a, + ComposedLayout const& b) +{ + CUTE_STATIC_ASSERT_V(b.offset() == Int<0>{}, "Require offset == 0."); + + return composition(composition(a, b.layout_a()), b.layout_b()); +} + +// +// complement +// + +template +CUTE_HOST_DEVICE constexpr +auto +complement(ComposedLayout const& layout, CoTarget const& cotarget) +{ + return complement(layout.layout_b(), cotarget); +} + +template +CUTE_HOST_DEVICE constexpr +auto +complement(ComposedLayout const& layout) +{ + return complement(layout, cosize(layout)); +} + +// +// inverse +// + +template +CUTE_HOST_DEVICE constexpr +auto +right_inverse(ComposedLayout const& layout) +{ + return composition(right_inverse(layout.layout_b()), right_inverse(layout.offset()), right_inverse(layout.layout_a())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +left_inverse(ComposedLayout const& layout) +{ + return composition(left_inverse(layout.layout_b()), left_inverse(layout.offset()), left_inverse(layout.layout_a())); +} + +// +// Other operations +// + +template +CUTE_HOST_DEVICE constexpr +auto +zip(ComposedLayout const& a) +{ + return composition(a.layout_a(), a.offset(), zip(a.layout_b())); +} + +// Partitions + +template +CUTE_HOST_DEVICE constexpr +auto +logical_divide(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), logical_divide(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tile_unzip(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), tile_unzip(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tiled_divide(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), tiled_divide(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +zipped_divide(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), zipped_divide(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +flat_divide(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), flat_divide(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +logical_product(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), logical_product(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +zipped_product(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), zipped_product(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tiled_product(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), tiled_product(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +flat_product(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), flat_product(a.layout_b(), b)); +} + 
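+// Note: each of the tilers/partitioners above and below simply forwards to LayoutB and re-wraps,
+// so the LayoutA transformation (e.g. a swizzle) is preserved. Illustrative identity, with
+// SwizzleFn and L standing in for any layout_a/layout_b pair:
+//   zipped_divide(ComposedLayout{SwizzleFn, _0{}, L}, tiler)
+//     == ComposedLayout{SwizzleFn, _0{}, zipped_divide(L, tiler)}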
+template +CUTE_HOST_DEVICE constexpr +auto +blocked_product(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), blocked_product(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +raked_product(ComposedLayout const& a, + Tiler const& b) +{ + return composition(a.layout_a(), a.offset(), raked_product(a.layout_b(), b)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +tile_to_shape(ComposedLayout const& layout, + Shape const& trg_shape, + ModeOrder const& ord_shape = {}) +{ + return composition(layout.layout_a(), layout.offset(), tile_to_shape(layout.layout_b(), trg_shape, ord_shape)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter(ComposedLayout const& layout, Shape const& trg_profile) +{ + return composition(layout.layout_a(), layout.offset(), filter(layout.layout_b(), trg_profile)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +coalesce(ComposedLayout const& layout) +{ + return composition(layout.layout_a(), layout.offset(), coalesce(layout.layout_b())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +coalesce(ComposedLayout const& layout, Shape const& trg_profile) +{ + return composition(layout.layout_a(), layout.offset(), coalesce(layout.layout_b(), trg_profile)); +} + + +// +// Upcast and Downcast +// + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(ComposedLayout const& layout) +{ + return composition(upcast(layout.layout_a()), upcast(layout.offset()), upcast(layout.layout_b())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +downcast(ComposedLayout const& layout) +{ + return composition(downcast(layout.layout_a()), downcast(layout.offset()), downcast(layout.layout_b())); +} + + +template +CUTE_HOST_DEVICE constexpr +auto +recast_layout(ComposedLayout const& layout) +{ + using scale = decltype(trait_ratio(sizeof_bits{}, sizeof_bits{})); + if constexpr (scale::num == 1 && scale::den == 1) { + return layout; + } + else if constexpr (scale::num == 1) { + return downcast(layout); + } + else if constexpr (scale::den == 1) { + return upcast(layout); + } + else { + return downcast(upcast(layout)); + } + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +max_alignment(ComposedLayout const& layout) +{ + // Do not attempt for general ComposedLayouts + //return gcd(max_alignment(layout.layout_a()), max_alignment(layout.offset()), max_alignment(layout.layout_b())); + return Int<1>{}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +nullspace(ComposedLayout const& layout) +{ + // Do not attempt for general ComposedLayouts + return Layout<_1,_0>{}; +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(ComposedLayout const& layout) +{ + print(layout.layout_a()); print(" o "); print(layout.offset()); print(" o "); print(layout.layout_b()); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, ComposedLayout const& layout) +{ + return os << layout.layout_a() << " o " << layout.offset() << " o " << layout.layout_b(); +} +#endif + +} // end namespace cute diff --git a/include/cute/numeric/arithmetic_tuple.hpp b/include/cute/numeric/arithmetic_tuple.hpp new file mode 100644 index 0000000000..2e46905719 --- /dev/null +++ b/include/cute/numeric/arithmetic_tuple.hpp @@ -0,0 +1,556 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +#include +#include +#include +#include +#include + +namespace cute +{ + +template +struct ArithmeticTuple : tuple +{ + template + CUTE_HOST_DEVICE constexpr + ArithmeticTuple(ArithmeticTuple const& u) + : tuple(static_cast const&>(u)) {} + + template + CUTE_HOST_DEVICE constexpr + ArithmeticTuple(tuple const& u) + : tuple(u) {} + + template + CUTE_HOST_DEVICE constexpr + ArithmeticTuple(U const&... u) + : tuple(u...) {} +}; + +template +struct is_tuple> : true_type {}; + +template +struct is_flat> : is_flat> {}; + +template +CUTE_HOST_DEVICE constexpr +auto +make_arithmetic_tuple(T const&... t) { + return ArithmeticTuple(t...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +as_arithmetic_tuple(T const& t) { + if constexpr (is_tuple::value) { + return detail::tapply(t, [](auto const& x){ return as_arithmetic_tuple(x); }, + [](auto const&... a){ return make_arithmetic_tuple(a...); }, + tuple_seq{}); + } else { + return t; + } +} + +// +// Numeric operators +// + +// Addition +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ArithmeticTuple const& t, ArithmeticTuple const& u) { + constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U))); + return transform_apply(append(t,Int<0>{}), append(u,Int<0>{}), plus{}, [](auto const&... 
a){ return make_arithmetic_tuple(a...); }); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ArithmeticTuple const& t, tuple const& u) { + return t + ArithmeticTuple(u); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(tuple const& t, ArithmeticTuple const& u) { + return ArithmeticTuple(t) + u; +} + +// Subtraction +template +CUTE_HOST_DEVICE constexpr +auto +operator-(ArithmeticTuple const& t, ArithmeticTuple const& u) { + constexpr int R = cute::max(int(sizeof...(T)), int(sizeof...(U))); + return transform_apply(append(t,Int<0>{}), append(u,Int<0>{}), minus{}, [](auto const&... a){ return make_arithmetic_tuple(a...); }); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator-(ArithmeticTuple const& t, tuple const& u) { + return t - ArithmeticTuple(u); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator-(tuple const& t, ArithmeticTuple const& u) { + return ArithmeticTuple(t) - u; +} + +// Negation +template +CUTE_HOST_DEVICE constexpr +auto +operator-(ArithmeticTuple const& t) { + return transform_apply(t, negate{}, [](auto const&... a){ return make_arithmetic_tuple(a...); }); +} + +// +// Special cases +// + +template +CUTE_HOST_DEVICE constexpr +ArithmeticTuple const& +operator+(C, ArithmeticTuple const& u) { + static_assert(t == 0, "Arithmetic tuple op+ error!"); + return u; +} + +template +CUTE_HOST_DEVICE constexpr +ArithmeticTuple const& +operator+(ArithmeticTuple const& t, C) { + static_assert(u == 0, "Arithmetic tuple op+ error!"); + return t; +} + +template +CUTE_HOST_DEVICE constexpr +ArithmeticTuple const& +operator-(C, ArithmeticTuple const& u) { + static_assert(t == 0, "Arithmetic tuple op- error!"); + return -u; +} + +template +CUTE_HOST_DEVICE constexpr +ArithmeticTuple const& +operator-(ArithmeticTuple const& t, C) { + static_assert(u == 0, "Arithmetic tuple op- error!"); + return t; +} + +// +// ArithmeticTupleIterator +// + +template +struct ArithmeticTupleIterator +{ + using value_type = ArithTuple; + using element_type = ArithTuple; + using reference = ArithTuple; + + ArithTuple coord_; + + CUTE_HOST_DEVICE constexpr + ArithmeticTupleIterator(ArithTuple const& coord = {}) : coord_(coord) {} + + CUTE_HOST_DEVICE constexpr + ArithTuple operator*() const { return coord_; } + + template + CUTE_HOST_DEVICE constexpr + auto operator[](Coord const& c) const { return *(*this + c); } + + template + CUTE_HOST_DEVICE constexpr + auto operator+(Coord const& c) const { + return ArithmeticTupleIterator>(coord_ + c); + } +}; + +template +CUTE_HOST_DEVICE constexpr +auto +make_inttuple_iter(Tuple const& t) { + return ArithmeticTupleIterator(as_arithmetic_tuple(t)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_inttuple_iter(T0 const& t0, T1 const& t1, Ts const&... ts) { + return make_inttuple_iter(cute::make_tuple(t0, t1, ts...)); +} + +// +// ArithmeticTuple "basis" elements +// A ScaledBasis is a (at least) rank-N+1 ArithmeticTuple: +// (_0,_0,...,T,_0,...) 
+// with value T in the Nth mode + +template +struct ScaledBasis : private tuple +{ + CUTE_HOST_DEVICE constexpr + ScaledBasis(T const& t = {}) : tuple(t) {} + + CUTE_HOST_DEVICE constexpr + decltype(auto) value() { return get<0>(static_cast &>(*this)); } + CUTE_HOST_DEVICE constexpr + decltype(auto) value() const { return get<0>(static_cast const&>(*this)); } + + CUTE_HOST_DEVICE static constexpr + auto mode() { return Int{}; } +}; + +template +struct is_scaled_basis : false_type {}; +template +struct is_scaled_basis> : true_type {}; + +template +struct is_integral> : true_type {}; + +// Get the scalar T out of a ScaledBasis +template +CUTE_HOST_DEVICE constexpr auto +basis_value(SB const& e) +{ + if constexpr (is_scaled_basis::value) { + return basis_value(e.value()); + } else { + return e; + } + CUTE_GCC_UNREACHABLE; +} + +// Apply the N... pack to another Tuple +template +CUTE_HOST_DEVICE decltype(auto) +basis_get(SB const& e, Tuple&& t) +{ + if constexpr (is_scaled_basis::value) { + return basis_get(e.value(), get(static_cast(t))); + } else { + return static_cast(t); + } + CUTE_GCC_UNREACHABLE; +} + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +to_atuple_i(T const& t, seq) { + return make_arithmetic_tuple((void(I),Int<0>{})..., t); +} + +} // end namespace detail + +// Turn a ScaledBases into a rank-N+1 ArithmeticTuple +// with N prefix 0s: (_0,_0,...N...,_0,T) +template +CUTE_HOST_DEVICE constexpr +auto +as_arithmetic_tuple(ScaledBasis const& t) { + return detail::to_atuple_i(as_arithmetic_tuple(t.value()), make_seq{}); +} + +namespace detail { + +template +struct Basis; + +template <> +struct Basis<> { + using type = Int<1>; +}; + +template +struct Basis { + using type = ScaledBasis::type, N>; +}; + +} // end namespace detail + +// Shortcut for writing ScaledBasis, N0>, N1>, ...> +// E<> := _1 +// E<0> := (_1,_0,_0,...) +// E<1> := (_0,_1,_0,...) +// E<0,0> := ((_1,_0,_0,...),_0,_0,...) +// E<0,1> := ((_0,_1,_0,...),_0,_0,...) +// E<1,0> := (_0,(_1,_0,_0,...),_0,...) +// E<1,1> := (_0,(_0,_1,_0,...),_0,...) +template +using E = typename detail::Basis::type; + +template +CUTE_HOST_DEVICE constexpr +auto +make_basis_like(Shape const& shape) +{ + if constexpr (is_integral::value) { + return Int<1>{}; + } else { + // Generate bases for each rank of shape + return transform(tuple_seq{}, shape, [](auto I, auto si) { + // Generate bases for each rank of si and add an i on front + using I_type = decltype(I); + return transform_leaf(make_basis_like(si), [](auto e) { + // MSVC has trouble capturing variables as constexpr, + // so that they can be used as template arguments. + // This is exactly what the code needs to do with i, unfortunately. + // The work-around is to define i inside the inner lambda, + // by using just the type from the enclosing scope. 
+ constexpr int i = I_type::value; + return ScaledBasis{}; + }); + }); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Arithmetic +// + +template +CUTE_HOST_DEVICE constexpr +auto +safe_div(ScaledBasis const& b, U const& u) +{ + auto t = safe_div(b.value(), u); + return ScaledBasis{t}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +shape_div(ScaledBasis const& b, U const& u) +{ + auto t = shape_div(b.value(), u); + return ScaledBasis{t}; +} + +// Equality +template +CUTE_HOST_DEVICE constexpr +auto +operator==(ScaledBasis const& t, ScaledBasis const& u) { + return bool_constant{} && t.value() == u.value(); +} + +// Not equal to anything else +template +CUTE_HOST_DEVICE constexpr +false_type +operator==(ScaledBasis const&, U const&) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +false_type +operator==(T const&, ScaledBasis const&) { + return {}; +} + +// Abs +template +CUTE_HOST_DEVICE constexpr +auto +abs(ScaledBasis const& e) { + return ScaledBasis{abs(e.value())}; +} + +// Multiplication +template +CUTE_HOST_DEVICE constexpr +auto +operator*(A const& a, ScaledBasis const& e) { + auto r = a * e.value(); + return ScaledBasis{r}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator*(ScaledBasis const& e, B const& b) { + auto r = e.value() * b; + return ScaledBasis{r}; +} + +// Addition +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ScaledBasis const& t, ScaledBasis const& u) { + return as_arithmetic_tuple(t) + as_arithmetic_tuple(u); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ScaledBasis const& t, ArithmeticTuple const& u) { + return as_arithmetic_tuple(t) + u; +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ArithmeticTuple const& t, ScaledBasis const& u) { + return t + as_arithmetic_tuple(u); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(C, ScaledBasis const& u) { + static_assert(t == 0, "ScaledBasis op+ error!"); + return u; +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator+(ScaledBasis const& t, C) { + static_assert(u == 0, "ScaledBasis op+ error!"); + return t; +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(ArithmeticTupleIterator const& iter) +{ + printf("ArithTuple"); print(iter.coord_); +} + +template +CUTE_HOST_DEVICE void print(ScaledBasis const& e) +{ + print(e.value()); printf("@%d", N); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, ArithmeticTupleIterator const& iter) +{ + return os << "ArithTuple" << iter.coord_; +} + +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, ScaledBasis const& e) +{ + return os << e.value() << "@" << N; +} +#endif + +} // end namespace cute + + +namespace CUTE_STL_NAMESPACE +{ + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +} // end namespace CUTE_STL_NAMESPACE + +#ifdef CUTE_STL_NAMESPACE_IS_CUDA_STD +namespace std +{ + +#if defined(__CUDACC_RTC__) +template +struct tuple_size; + +template +struct tuple_element; +#endif + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : CUTE_STL_NAMESPACE::tuple_element> +{}; + +template +struct tuple_size> + : CUTE_STL_NAMESPACE::integral_constant +{}; + +template +struct tuple_element> + : 
CUTE_STL_NAMESPACE::tuple_element> +{}; + +} // end namespace std +#endif // CUTE_STL_NAMESPACE_IS_CUDA_STD diff --git a/include/cute/numeric/complex.hpp b/include/cute/numeric/complex.hpp new file mode 100644 index 0000000000..7dd9ea5bf0 --- /dev/null +++ b/include/cute/numeric/complex.hpp @@ -0,0 +1,76 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE + +#include // cutlass::complexm, cutlass::real, cutlass::imag, cutlass::is_complex + +namespace cute +{ + +using cutlass::complex; +using cutlass::is_complex; +using cutlass::RealType; +using cutlass::real; +using cutlass::imag; +using cutlass::conj; + +template +static constexpr auto is_complex_v = is_complex::value; + +/// Fused multiply-add for complex numbers +template +CUTE_HOST_DEVICE constexpr +void +fma(complex & d, + complex const& a, + complex const& b, + complex const& c) +{ + fma(d.real(), a.real(), b.real(), c.real()); + fma(d.imag(), a.real(), b.imag(), c.imag()); + fma(d.real(), -a.imag(), b.imag(), d.real()); + fma(d.imag(), a.imag(), b.real(), d.imag()); +} + +/// Fused multiply-add for triplets +template +CUTE_HOST_DEVICE constexpr +void +fma(complex const& a, + complex const& b, + complex & c) +{ + return fma(c, a, b, c); +} + +} // end namespace cute diff --git a/include/cute/numeric/int.hpp b/include/cute/numeric/int.hpp new file mode 100644 index 0000000000..571b3e3ed0 --- /dev/null +++ b/include/cute/numeric/int.hpp @@ -0,0 +1,106 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +#include // CUTE_STL_NAMESPACE + +#include // cutlass::int2b_t, cutlass::int4b_t + +namespace cute +{ + +// +// Signed integers +// + +using int2_t = cutlass::int2b_t; +using int4_t = cutlass::int4b_t; +using CUTE_STL_NAMESPACE::int8_t; +using CUTE_STL_NAMESPACE::int16_t; +using CUTE_STL_NAMESPACE::int32_t; +using CUTE_STL_NAMESPACE::int64_t; + +template struct int_bit; +template <> struct int_bit< 2> { using type = int2_t; }; +template <> struct int_bit< 4> { using type = int4_t; }; +template <> struct int_bit< 8> { using type = int8_t; }; +template <> struct int_bit< 16> { using type = int16_t; }; +template <> struct int_bit< 32> { using type = int32_t; }; +template <> struct int_bit< 64> { using type = int64_t; }; + +template +using int_bit_t = typename int_bit::type; + +template +using int_byte = int_bit<8*N>; + +template +using int_byte_t = typename int_byte::type; + +// +// Unsigned integers +// + +using uint1_t = cutlass::uint1b_t; +using uint2_t = cutlass::uint2b_t; +using uint4_t = cutlass::uint4b_t; +using CUTE_STL_NAMESPACE::uint8_t; +using CUTE_STL_NAMESPACE::uint16_t; +using CUTE_STL_NAMESPACE::uint32_t; +using CUTE_STL_NAMESPACE::uint64_t; +using cutlass::uint128_t; + +template struct uint_bit; +template <> struct uint_bit< 1> { using type = uint1_t; }; +template <> struct uint_bit< 2> { using type = uint2_t; }; +template <> struct uint_bit< 4> { using type = uint4_t; }; +template <> struct uint_bit< 8> { using type = uint8_t; }; +template <> struct uint_bit< 16> { using type = uint16_t; }; +template <> struct uint_bit< 32> { using type = uint32_t; }; +template <> struct uint_bit< 64> { using type = uint64_t; }; +template <> struct uint_bit<128> { using type = cutlass::uint128_t; }; + +template +using uint_bit_t = typename uint_bit::type; + +template 
+using uint_byte = uint_bit<8*N>; + +template +using uint_byte_t = typename uint_byte::type; + +} // namespace cute diff --git a/include/cute/numeric/integer_sequence.hpp b/include/cute/numeric/integer_sequence.hpp new file mode 100644 index 0000000000..6080179585 --- /dev/null +++ b/include/cute/numeric/integer_sequence.hpp @@ -0,0 +1,151 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include +#include +#include + +namespace cute +{ + +using CUTE_STL_NAMESPACE::integer_sequence; +using CUTE_STL_NAMESPACE::make_integer_sequence; + +namespace detail { + +template +struct range_impl; + +template +struct range_impl, Begin> { + using type = integer_sequence; +}; + +template +struct reverse_impl; + +template +struct reverse_impl> { + using type = integer_sequence; +}; + +} // end namespace detail + +template +using make_integer_range = typename detail::range_impl< + T, + make_integer_sequence 0) ? 
(End-Begin) : 0>, + Begin>::type; + +template +using make_integer_sequence_reverse = typename detail::reverse_impl< + make_integer_sequence>::type; + +// +// Common aliases +// + +// int_sequence + +template +using int_sequence = integer_sequence; + +template +using make_int_sequence = make_integer_sequence; + +template +using make_int_rsequence = make_integer_sequence_reverse; + +template +using make_int_range = make_integer_range; + +// index_sequence + +template +using index_sequence = integer_sequence; + +template +using make_index_sequence = make_integer_sequence; + +template +using make_index_rsequence = make_integer_sequence_reverse; + +template +using make_index_range = make_integer_range; + +// +// Shortcuts +// + +template +using seq = int_sequence; + +template +using make_seq = make_int_sequence; + +template +using make_rseq = make_int_rsequence; + +template +using make_range = make_int_range; + +template +using tuple_seq = make_seq>::value>; + +template +using tuple_rseq = make_rseq>::value>; + +// +// Specialize cute::tuple-traits for std::integer_sequence +// + +template +struct tuple_size> + : cute::integral_constant +{}; + +template +struct tuple_element> +{ + constexpr static T idx[sizeof...(Is)] = {Is...}; + using type = cute::integral_constant; +}; + +template +CUTE_HOST_DEVICE constexpr +tuple_element_t> +get(integer_sequence) { + static_assert(I < sizeof...(Ints), "Index out of range"); + return {}; +} + +} // end namespace cute diff --git a/include/cute/numeric/integral_constant.hpp b/include/cute/numeric/integral_constant.hpp new file mode 100644 index 0000000000..3a8d036eef --- /dev/null +++ b/include/cute/numeric/integral_constant.hpp @@ -0,0 +1,517 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include // cute::max, etc +#include // cute::print +#include // __CUTE_REQUIRES, cute::is_std_integral + +namespace cute +{ + +// A constant value: short name and type-deduction for fast compilation +template +struct C { + using type = C; + static constexpr auto value = v; + using value_type = decltype(v); + CUTE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; } + CUTE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; } +}; + +// Deprecate +template +using constant = C; + +template +using bool_constant = C; + +using true_type = bool_constant; +using false_type = bool_constant; + +// A more std:: conforming integral_constant that enforces type but interops with C +template +struct integral_constant : C { + using type = integral_constant; + static constexpr T value = v; + using value_type = T; + // Disambiguate C::operator value_type() + //CUTE_HOST_DEVICE constexpr operator value_type() const noexcept { return value; } + CUTE_HOST_DEVICE constexpr value_type operator()() const noexcept { return value; } +}; + +// +// Traits +// + +// Use cute::is_std_integral to match built-in integral types (int, int64_t, unsigned, etc) +// Use cute::is_integral to match both built-in integral types AND static integral types. + +template +struct is_integral : bool_constant::value> {}; +template +struct is_integral > : true_type {}; +template +struct is_integral> : true_type {}; + +// Register FastDivmod as the integral type +template<> +struct is_integral : true_type {}; + +// is_static detects if an (abstract) value is defined completely by its type (no members) +template +struct is_static : bool_constant>::value> {}; + +template +constexpr bool is_static_v = is_static::value; + +// is_constant detects if a type is a static integral type and if v is equal to a value + +template +struct is_constant : false_type {}; +template +struct is_constant : is_constant {}; +template +struct is_constant : is_constant {}; +template +struct is_constant : is_constant {}; +template +struct is_constant : is_constant {}; +template +struct is_constant > : bool_constant {}; +template +struct is_constant> : bool_constant {}; + +// +// Specializations +// + +template +using Int = C; + +using _m32 = Int<-32>; +using _m24 = Int<-24>; +using _m16 = Int<-16>; +using _m12 = Int<-12>; +using _m10 = Int<-10>; +using _m9 = Int<-9>; +using _m8 = Int<-8>; +using _m7 = Int<-7>; +using _m6 = Int<-6>; +using _m5 = Int<-5>; +using _m4 = Int<-4>; +using _m3 = Int<-3>; +using _m2 = Int<-2>; +using _m1 = Int<-1>; +using _0 = Int<0>; +using _1 = Int<1>; +using _2 = Int<2>; +using _3 = Int<3>; +using _4 = Int<4>; +using _5 = Int<5>; +using _6 = Int<6>; +using _7 = Int<7>; +using _8 = Int<8>; +using _9 = Int<9>; +using _10 = Int<10>; +using _12 = Int<12>; +using _16 = Int<16>; +using _24 = Int<24>; +using _32 = Int<32>; +using _40 = Int<40>; +using _48 = Int<48>; +using _56 = Int<56>; +using _64 = Int<64>; +using _72 = Int<72>; +using _80 = Int<80>; +using _88 = Int<88>; +using _96 = Int<96>; +using _104 = Int<104>; +using _112 = Int<112>; +using _120 = Int<120>; +using _128 = Int<128>; +using _136 = Int<136>; +using _144 = Int<144>; +using _152 = Int<152>; +using _160 = Int<160>; +using _168 = Int<168>; +using _176 = Int<176>; +using _184 = Int<184>; +using _192 = Int<192>; +using _200 = Int<200>; +using _208 = Int<208>; +using _216 = Int<216>; +using _224 = Int<224>; 
+using _232 = Int<232>; +using _240 = Int<240>; +using _248 = Int<248>; +using _256 = Int<256>; +using _384 = Int<384>; +using _512 = Int<512>; +using _768 = Int<768>; +using _1024 = Int<1024>; +using _2048 = Int<2048>; +using _4096 = Int<4096>; +using _8192 = Int<8192>; +using _16384 = Int<16384>; +using _32768 = Int<32768>; +using _65536 = Int<65536>; +using _131072 = Int<131072>; +using _262144 = Int<262144>; +using _524288 = Int<524288>; + +/***************/ +/** Operators **/ +/***************/ + +#define CUTE_LEFT_UNARY_OP(OP) \ + template \ + CUTE_HOST_DEVICE constexpr \ + C<(OP t)> operator OP (C) { \ + return {}; \ + } +#define CUTE_RIGHT_UNARY_OP(OP) \ + template \ + CUTE_HOST_DEVICE constexpr \ + C<(t OP)> operator OP (C) { \ + return {}; \ + } +#define CUTE_BINARY_OP(OP) \ + template \ + CUTE_HOST_DEVICE constexpr \ + C<(t OP u)> operator OP (C, C) { \ + return {}; \ + } + +CUTE_LEFT_UNARY_OP(+); +CUTE_LEFT_UNARY_OP(-); +CUTE_LEFT_UNARY_OP(~); +CUTE_LEFT_UNARY_OP(!); +CUTE_LEFT_UNARY_OP(*); + +CUTE_BINARY_OP( +); +CUTE_BINARY_OP( -); +CUTE_BINARY_OP( *); +CUTE_BINARY_OP( /); +CUTE_BINARY_OP( %); +CUTE_BINARY_OP( &); +CUTE_BINARY_OP( |); +CUTE_BINARY_OP( ^); +CUTE_BINARY_OP(<<); +CUTE_BINARY_OP(>>); + +CUTE_BINARY_OP(&&); +CUTE_BINARY_OP(||); + +CUTE_BINARY_OP(==); +CUTE_BINARY_OP(!=); +CUTE_BINARY_OP( >); +CUTE_BINARY_OP( <); +CUTE_BINARY_OP(>=); +CUTE_BINARY_OP(<=); + +#undef CUTE_BINARY_OP +#undef CUTE_LEFT_UNARY_OP +#undef CUTE_RIGHT_UNARY_OP + +// +// Mixed static-dynamic special cases +// + +template ::value && t == 0)> +CUTE_HOST_DEVICE constexpr +C<0> +operator*(C, U) { + return {}; +} + +template ::value && t == 0)> +CUTE_HOST_DEVICE constexpr +C<0> +operator*(U, C) { + return {}; +} + +template ::value && t == 0)> +CUTE_HOST_DEVICE constexpr +C<0> +operator/(C, U) { + return {}; +} + +template ::value && (t == 1 || t == -1))> +CUTE_HOST_DEVICE constexpr +C<0> +operator%(U, C) { + return {}; +} + +template ::value && t == 0)> +CUTE_HOST_DEVICE constexpr +C<0> +operator%(C, U) { + return {}; +} + +template ::value && t == 0)> +CUTE_HOST_DEVICE constexpr +C<0> +operator&(C, U) { + return {}; +} + +template ::value && t == 0)> +CUTE_HOST_DEVICE constexpr +C<0> +operator&(U, C) { + return {}; +} + +template ::value && !bool(t))> +CUTE_HOST_DEVICE constexpr +C +operator&&(C, U) { + return {}; +} + +template ::value && !bool(t))> +CUTE_HOST_DEVICE constexpr +C +operator&&(U, C) { + return {}; +} + +template ::value && bool(t))> +CUTE_HOST_DEVICE constexpr +C +operator||(C, U) { + return {}; +} + +template ::value && bool(t))> +CUTE_HOST_DEVICE constexpr +C +operator||(U, C) { + return {}; +} + +// +// Named functions from math.hpp +// + +#define CUTE_NAMED_UNARY_FN(OP) \ + template \ + CUTE_HOST_DEVICE constexpr \ + C OP (C) { \ + return {}; \ + } +#define CUTE_NAMED_BINARY_FN(OP) \ + template \ + CUTE_HOST_DEVICE constexpr \ + C OP (C, C) { \ + return {}; \ + } \ + template ::value)> \ + CUTE_HOST_DEVICE constexpr \ + auto OP (C, U u) { \ + return OP(t,u); \ + } \ + template ::value)> \ + CUTE_HOST_DEVICE constexpr \ + auto OP (T t, C) { \ + return OP(t,u); \ + } + +CUTE_NAMED_UNARY_FN(abs); +CUTE_NAMED_UNARY_FN(signum); +CUTE_NAMED_UNARY_FN(has_single_bit); + +CUTE_NAMED_BINARY_FN(max); +CUTE_NAMED_BINARY_FN(min); +CUTE_NAMED_BINARY_FN(shiftl); +CUTE_NAMED_BINARY_FN(shiftr); +CUTE_NAMED_BINARY_FN(gcd); +CUTE_NAMED_BINARY_FN(lcm); + +#undef CUTE_NAMED_UNARY_FN +#undef CUTE_NAMED_BINARY_FN + +// +// Other functions +// + +template +CUTE_HOST_DEVICE constexpr +C +safe_div(C, 
C) { + static_assert(t % u == 0, "Static safe_div requires t % u == 0"); + return {}; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +safe_div(C, U u) { + return t / u; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +safe_div(T t, C) { + return t / u; +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +conditional_return(true_type, TrueType&& t, FalseType&&) { + return static_cast(t); +} + +template +CUTE_HOST_DEVICE constexpr +decltype(auto) +conditional_return(false_type, TrueType&&, FalseType&& f) { + return static_cast(f); +} + +template +CUTE_HOST_DEVICE constexpr +auto +conditional_return(bool b, C const&, C const&) { + return C{}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +conditional_return(bool b, C const&, C const&) { + return b ? v : u; +} + +// TrueType and FalseType must have a common type +template +CUTE_HOST_DEVICE constexpr +auto +conditional_return(bool b, TrueType const& t, FalseType const& f) { + return b ? t : f; +} + +// TrueType and FalseType don't require a common type +template +CUTE_HOST_DEVICE constexpr +auto +conditional_return(TrueType const& t, FalseType const& f) { + if constexpr (b) { + return t; + } else { + return f; + } +} + +template +CUTE_HOST_DEVICE constexpr +auto +static_value() +{ + if constexpr (is_std_integral::value) { + return Int{}; + } else { + return Trait::value; + } + CUTE_GCC_UNREACHABLE; +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(C) { + printf("_"); + ::cute::print(Value); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, C const&) { + return os << "_" << t; +} +#endif + + +namespace detail { + +// parse_int_digits takes a variadic number of digits and converts them into an int +template +constexpr uint64_t parse_int_digits(uint64_t result, int digit, Ts... digits) +{ + if constexpr (sizeof...(Ts) == 0) { + return 10 * result + digit; + } else { + return parse_int_digits(10 * result + digit, digits...); + } +} + +} // end namespace detail + + +// This user-defined literal operator allows cute::constant written as literals. For example, +// +// auto var = 32_c; +// +// var has type cute::constant. +// +template +constexpr cute::constant operator "" _c() +{ + static_assert((('0' <= digits && digits <= '9') && ...), + "Expected 0 <= digit <= 9 for each digit of the integer."); + return {}; +} + +} // end namespace cute diff --git a/include/cute/numeric/integral_ratio.hpp b/include/cute/numeric/integral_ratio.hpp new file mode 100644 index 0000000000..a614bdb2d9 --- /dev/null +++ b/include/cute/numeric/integral_ratio.hpp @@ -0,0 +1,293 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::false_type, cute::true_type +#include // cute::signum +#include // __CUTE_REQUIRES + +namespace cute +{ + +/** Compile-time rational arithmetic type. + * Like cute::C for std::integral_constant, cute::R for std::ratio has a short name + * for error messages and compile times. + * The static data members @a num and @a den represent the reduced numerator and denominator + * of the rational value. Thus, two cute::R types with different @a n or @a d are distinct types + * even if they represent the same rational value. + * A cute::R exposes the reduced canonical type via its ::type member. + * That is, cute::R<3,6>::type is cute::R<1,2> and cute::R<6,3>::type is cute::C<2>. + * A cute::R::value can be used much like any other trait::value. It can be involved in + * arithmetic expressions (according to the operator-overloads for cute::C and cute::R, + * though these may be incomplete) but with a potential rational value rather than an integral value. 
+ */ +template +class R { + static_assert(d != 0); + static constexpr auto an = abs(n); + static constexpr auto ad = abs(d); + static constexpr auto g = gcd(an, ad); + + public: + static constexpr auto num = signum(n) * signum(d) * an / g; + static constexpr auto den = ad / g; + // RI: den >= 1 && gcd(abs(num),den) == 1 + using type = typename conditional, R>::type; +}; + +template +struct is_ratio : false_type {}; +template +struct is_ratio> : true_type {}; + +template +CUTE_HOST_DEVICE constexpr +typename R::type +ratio(C, C) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +ratio(C, R) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +ratio(R, C) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +ratio(R, R) { + return {}; +} + +// +// Non-reduced ratio implementations +// + +template +CUTE_HOST_DEVICE constexpr +R +nratio(C, C) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +R +nratio(C, R) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +R +nratio(R, C) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +R +nratio(R, R) { + return {}; +} + +// +// Operators +// + +template +CUTE_HOST_DEVICE constexpr +typename R::type +operator*(R, R) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +operator*(R, C) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +operator*(C, R) { + return {}; +} + +// Product with dynamic type needs to produce an integer... +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +operator*(C const& c, R) { + return c * R::num / R::den; +} + +// Product with dynamic type needs to produce an integer... +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +operator*(R, C const& c) { + return c * R::num / R::den; +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator/(C const& c, R) { + return c * R{}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +operator+(R, R) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +operator+(R, C) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +operator+(C, R) { + return {}; +} + +///////////////// +// Comparisons // +///////////////// + +template +CUTE_HOST_DEVICE constexpr +bool_constant::num == R::num && R::den == R::den> +operator==(R, R) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +bool_constant::num == c && R::den == 1> +operator==(R, C) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +bool_constant::num == c && R::den == 1> +operator==(C, R) { + return {}; +} + +/////////////////////// +// Special functions // +/////////////////////// + +template +CUTE_HOST_DEVICE constexpr +typename R::type +gcd(R, R) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +gcd(R, C) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +gcd(C, R) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +typename R::type +abs(R) { + return {}; +} + +template +CUTE_HOST_DEVICE constexpr +int32_t +log_2(R) { + static_assert(R::num > 0); + static_assert(R::den > 0); + return log_2(static_cast(R::num)) - log_2(static_cast(R::den)); +} + +// @return A non-reduced ratio cute::R of the Trait0::value / Trait1::value +template +CUTE_HOST_DEVICE constexpr +auto +trait_ratio(Trait0, Trait1) { + return nratio(static_value(), static_value()); +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(R) { + print(C{}); print("/"); print(C{}); +} 
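The `cute::R` ratio type above keeps a reduced `num`/`den` pair and only collapses to an integral `C` when the denominator reaches 1; `trait_ratio` deliberately returns the non-reduced form. A small sketch of those rules, assuming the surrounding CuTe headers; the `halve` helper and the `uint8_t`/`uint4_t` pairing are illustrative choices, not taken from this change:

```cpp
#include <cute/numeric/integral_ratio.hpp>   // cute::R, ratio, nratio, trait_ratio
#include <cute/numeric/numeric_types.hpp>    // cute::sizeof_bits, cute::uint4_t

using namespace cute;

// R<n,d> stores the reduced numerator/denominator and canonicalizes through ::type.
static_assert(R<3,6>::num == 1 && R<3,6>::den == 2);   // R<3,6>::type is R<1,2>
static_assert(R<6,3>::num == 2 && R<6,3>::den == 1);   // R<6,3>::type is C<2>

// ratio() of two static integers yields the reduced result as a value.
static_assert(decltype(ratio(Int<2>{}, Int<8>{}))::num == 1 &&
              decltype(ratio(Int<2>{}, Int<8>{}))::den == 4);

// Mixing with a dynamic integer falls back to integer arithmetic (c * num / den).
constexpr int halve(int n) { return n * ratio(Int<1>{}, Int<2>{}); }
static_assert(halve(10) == 5);

// trait_ratio compares two traits' values as a non-reduced R; this is the kind of
// scale that recast_layout inspects when deciding between upcast and downcast.
using scale = decltype(trait_ratio(sizeof_bits<uint8_t>{}, sizeof_bits<uint4_t>{}));
static_assert(scale::num == 2 && scale::den == 1);
```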
+ +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, R) { + return os << "_" << C{} << "/" << C{}; +} +#endif + +} // end namespace cute diff --git a/include/cute/numeric/math.hpp b/include/cute/numeric/math.hpp new file mode 100644 index 0000000000..e493a3a953 --- /dev/null +++ b/include/cute/numeric/math.hpp @@ -0,0 +1,356 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // __CUTE_REQUIRES + +#include + +namespace cute +{ + +// +// Common Operations +// + +template ::value && + is_arithmetic::value)> +CUTE_HOST_DEVICE constexpr +auto +max(T const& t, U const& u) { + return t < u ? u : t; +} + +template ::value && + is_arithmetic::value)> +CUTE_HOST_DEVICE constexpr +auto +min(T const& t, U const& u) { + return t < u ? t : u; +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +abs(T const& t) { + if constexpr (is_signed::value) { + return t < T(0) ? -t : t; + } else { + return t; + } + + CUTE_GCC_UNREACHABLE; +} + +// Returns 1 if x > 0, -1 if x < 0, and 0 if x is zero. 
+template ::value)> +CUTE_HOST_DEVICE constexpr +int +signum(T const& x) { + if constexpr (is_signed::value) { + return (T(0) < x) - (x < T(0)); + } else { + return T(0) < x; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// C++17 operations +// + +// Greatest common divisor of two positive integers +template ::value && + is_std_integral::value)> +CUTE_HOST_DEVICE constexpr +cute::common_type_t +gcd(T t, U u) { + while (true) { + if (t == 0) { return u; } + u %= t; + if (u == 0) { return t; } + t %= u; + } +} + +// Least common multiple of two positive integers +template ::value && + is_std_integral::value)> +CUTE_HOST_DEVICE constexpr +cute::common_type_t +lcm(T const& t, U const& u) { + return (t / gcd(t,u)) * u; +} + +// +// C++20 operations +// + +// Checks if a number is an integral power of two +template +CUTE_HOST_DEVICE constexpr +bool +has_single_bit(T x) { + return x != 0 && (x & (x - 1)) == 0; +} + +// Smallest number of bits needed to represent the given value +// For x == 0, this is 0 +// For x != 0, this is 1 + floor(log2(x)) +// bit_width( 0b0000 ) = 0 +// bit_width( 0b0001 ) = 1 +// bit_width( 0b0010 ) = 2 +// bit_width( 0b0011 ) = 2 +// bit_width( 0b0100 ) = 3 +// bit_width( 0b0101 ) = 3 +// bit_width( 0b0110 ) = 3 +// bit_width( 0b0111 ) = 3 +template +CUTE_HOST_DEVICE constexpr +int +bit_width(T x) { + static_assert(is_unsigned::value, "Only to be used for unsigned types."); + constexpr int N = (numeric_limits::digits == 64 ? 6 : + (numeric_limits::digits == 32 ? 5 : + (numeric_limits::digits == 16 ? 4 : + (numeric_limits::digits == 8 ? 3 : (assert(false),0))))); + T r = 0; + for (int i = N - 1; i >= 0; --i) { + T shift = (x > ((T(1) << (T(1) << i))-1)) << i; + x >>= shift; + r |= shift; + } + return r + (x != 0); +} + +// Smallest integral power of two not less than the given value +// bit_ceil( 0b00000000 ) = 0b00000001 +// bit_ceil( 0b00000001 ) = 0b00000001 +// bit_ceil( 0b00000010 ) = 0b00000010 +// bit_ceil( 0b00000011 ) = 0b00000100 +// bit_ceil( 0b00000100 ) = 0b00000100 +// bit_ceil( 0b00000101 ) = 0b00001000 +// bit_ceil( 0b00000110 ) = 0b00001000 +// bit_ceil( 0b00000111 ) = 0b00001000 +// bit_ceil( 0b00001000 ) = 0b00001000 +// bit_ceil( 0b00001001 ) = 0b00010000 +template +CUTE_HOST_DEVICE constexpr +T +bit_ceil(T x) { + return x == 0 ? T(1) : (T(1) << bit_width(x - 1)); +} + +// Largest integral power of two not greater than the given value +// bit_floor( 0b00000000 ) = 0b00000000 +// bit_floor( 0b00000001 ) = 0b00000001 +// bit_floor( 0b00000010 ) = 0b00000010 +// bit_floor( 0b00000011 ) = 0b00000010 +// bit_floor( 0b00000100 ) = 0b00000100 +// bit_floor( 0b00000101 ) = 0b00000100 +// bit_floor( 0b00000110 ) = 0b00000100 +// bit_floor( 0b00000111 ) = 0b00000100 +// bit_floor( 0b00001000 ) = 0b00001000 +// bit_floor( 0b00001001 ) = 0b00001000 +template +CUTE_HOST_DEVICE constexpr +T +bit_floor(T x) { + return x == 0 ? 0 : (T(1) << (bit_width(x) - 1)); +} + +template +CUTE_HOST_DEVICE constexpr T rotl(T x, int s); +template +CUTE_HOST_DEVICE constexpr T rotr(T x, int s); + +// Computes the result of circular bitwise left-rotation +template +CUTE_HOST_DEVICE constexpr +T +rotl(T x, int s) { + constexpr int N = numeric_limits::digits; + return static_cast(s == 0 ? x : s > 0 ? (x << s) | (x >> (N - s)) : rotr(x, -s)); +} + +// Computes the result of circular bitwise right-rotation +template +CUTE_HOST_DEVICE constexpr +T +rotr(T x, int s) { + constexpr int N = numeric_limits::digits; + return static_cast(s == 0 ? x : s > 0 ? 
(x >> s) | (x << (N - s)) : rotl(x, -s)); +} + +// Counts the number of consecutive 0 bits, starting from the most significant bit +// countl_zero( 0b00000000 ) = 8 +// countl_zero( 0b11111111 ) = 0 +// countl_zero( 0b00011100 ) = 3 +template +CUTE_HOST_DEVICE constexpr +int +countl_zero(T x) { + return numeric_limits::digits - bit_width(x); +} + +// Counts the number of consecutive 1 bits, starting from the most significant bit +// countl_one( 0b00000000 ) = 0 +// countl_one( 0b11111111 ) = 8 +// countl_one( 0b11100011 ) = 3 +template +CUTE_HOST_DEVICE constexpr +int +countl_one(T x) { + return countl_zero(~x); +} + +// Counts the number of consecutive 0 bits, starting from the least significant bit +// countr_zero( 0b00000000 ) = 8 +// countr_zero( 0b11111111 ) = 0 +// countr_zero( 0b00011100 ) = 2 +template +CUTE_HOST_DEVICE constexpr +int +countr_zero(T x) { + return x == 0 ? numeric_limits::digits : bit_width(T(x & T(-x))) - 1; // bit_width of the LSB +} + +// Counts the number of consecutive 1 bits, starting from the least significant bit +// countr_one( 0b00000000 ) = 0 +// countr_one( 0b11111111 ) = 8 +// countr_one( 0b11100011 ) = 2 +template +CUTE_HOST_DEVICE constexpr +int +countr_one(T x) { + return countr_zero(~x); +} + +// Counts the number of 1 bits in an unsigned integer +// popcount( 0b00000000 ) = 0 +// popcount( 0b11111111 ) = 8 +// popcount( 0b00011101 ) = 4 +template +CUTE_HOST_DEVICE constexpr +int +popcount(T x) { + int c = 0; + while (x) { + ++c; + x &= x - 1; // clear the least significant bit set + } + return c; +} + +// +// Custom operations +// + +// Computes the result of bitwise left-shift +template +CUTE_HOST_DEVICE constexpr +auto +shiftl(T x, int s) { + return s >= 0 ? (x << s) : (x >> -s); +} + +// Computes the result of bitwise right-shift +template +CUTE_HOST_DEVICE constexpr +auto +shiftr(T x, int s) { + return s >= 0 ? (x >> s) : (x << -s); +} + +// Safe divide +// @pre t % u == 0 +// @result t / u +template ::value && + is_std_integral::value)> +CUTE_HOST_DEVICE constexpr +auto +safe_div(T const& t, U const& u) { + //assert(t % u == 0); + return t / u; +} + +/** + * log2 computation + */ + +template +CUTE_HOST_DEVICE constexpr +int32_t +log_2(T x) { + assert(x > 0); + static_assert(is_unsigned::value, "Only to be used for unsigned integral types."); + return static_cast(bit_width(x)) - 1; +} + +template +struct DivModReturnType { + IntDiv div_; + IntMod mod_; + CUTE_HOST_DEVICE constexpr + DivModReturnType(IntDiv const& div, IntMod const& mod) : div_(div), mod_(mod) {} +}; + +// General divmod +template +CUTE_HOST_DEVICE constexpr +auto +divmod(CInt0 const& a, CInt1 const& b) { + return DivModReturnType{a / b, a % b}; +} + +// Specialized function with fastDivmod input +template +CUTE_HOST_DEVICE constexpr +auto +divmod(CInt const& a, cutlass::FastDivmod const& b) { + using val_div_type = typename cutlass::FastDivmod::value_div_type; + using val_mod_type = typename cutlass::FastDivmod::value_mod_type; + val_div_type div = 0; + val_mod_type mod = 0; + b(div, mod, a); + return DivModReturnType{div, mod}; +} + +} // namespace cute diff --git a/include/cute/numeric/numeric_types.hpp b/include/cute/numeric/numeric_types.hpp new file mode 100644 index 0000000000..b9943b8ca3 --- /dev/null +++ b/include/cute/numeric/numeric_types.hpp @@ -0,0 +1,136 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::int2_t, cute::int4_t, etc + +#include // cutlass::sizeof_bits +#include // cutlass::float_e4m3_t, cutlass::float_e5m2_t, etc + +namespace cute { + +template +struct sizeof_bits : public cutlass::sizeof_bits {}; + +// DO NOT change auto to int, sizeof_bits use integral_ratio instead of int +template +static constexpr auto sizeof_bits_v = sizeof_bits::value; + +using cutlass::bits_to_bytes; +using cutlass::bytes_to_bits; + +using cutlass::is_subbyte; + +template +static constexpr auto is_subbyte_v = is_subbyte::value; + +using cutlass::half_t; +using cutlass::bfloat16_t; + +using cutlass::tfloat32_t; + +// Umbrella floating-point 8-bit data type : type_erased_dynamic_float8_t +// This umbrella datatype can be enabled when a user provides a specific +// datatype in runtime argument list. 
+using cutlass::type_erased_dynamic_float8_t; +using cutlass::float_e4m3_t; +using cutlass::float_e5m2_t; + +using cutlass::uint1b_t; +using cutlass::int2b_t; +using cutlass::uint2b_t; +using cutlass::int4b_t; +using cutlass::uint4b_t; +using cutlass::bin1_t; + + +// +// Print utility +// + +CUTE_HOST_DEVICE +void +print(half_t a) { + printf("%f", static_cast(a)); +} + +CUTE_HOST_DEVICE +void +print(bfloat16_t a) { + printf("%f", static_cast(a)); +} + + +CUTE_HOST_DEVICE +void +print(tfloat32_t a) { + printf("%f", static_cast(a)); +} + +CUTE_HOST_DEVICE +void +print(float_e4m3_t a) { + printf("%f", static_cast(a)); +} + +CUTE_HOST_DEVICE +void +print(float_e5m2_t a) { + printf("%f", static_cast(a)); +} + +CUTE_HOST_DEVICE void +pretty_print(bfloat16_t v) { + printf("%*.2f", 8, float(v)); +} + +CUTE_HOST_DEVICE void +pretty_print(half_t v) { + printf("%*.2f", 8, float(v)); +} + +CUTE_HOST_DEVICE void +pretty_print(tfloat32_t v) { + printf("%*.2e", 10, static_cast(v)); +} + +CUTE_HOST_DEVICE void +pretty_print(float_e4m3_t t) { + printf("%*.2f", 8, static_cast(t)); +} + +CUTE_HOST_DEVICE void +pretty_print(float_e5m2_t t) { + printf("%*.2f", 8, static_cast(t)); +} + +} // namespace cute diff --git a/include/cute/numeric/real.hpp b/include/cute/numeric/real.hpp new file mode 100644 index 0000000000..4ce58dfa18 --- /dev/null +++ b/include/cute/numeric/real.hpp @@ -0,0 +1,74 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include + +namespace cute +{ + +/// Generic add +template +CUTE_HOST_DEVICE constexpr +void +add(C& c, A const& a, B const& b) +{ + c = a + b; +} + +/// Generic multiply +template +CUTE_HOST_DEVICE constexpr +void +mul(C& c, A const& a, B const& b) +{ + c = a * b; +} + +/// Generic fused multiply-add +template +CUTE_HOST_DEVICE constexpr +void +fma(D& d, A const& a, B const& b, C const& c) +{ + d = a * b + c; +} + +/// Fused multiply-add for triplets +template +CUTE_HOST_DEVICE constexpr +void +fma(A const& a, B const& b, C& c) +{ + return fma(c, a, b, c); +} + +} // end namespace cute diff --git a/include/cute/pointer.hpp b/include/cute/pointer.hpp new file mode 100644 index 0000000000..cc49b6a356 --- /dev/null +++ b/include/cute/pointer.hpp @@ -0,0 +1,330 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::iter_adaptor +#include +#include // cute::subbyte_iterator +#include // cute::true_type, cute::false_type +#include // sizeof_bits + +namespace cute +{ + +// +// recast_ptr -- Create an iterator over values of type T. +// For most types this will simply be T*, but certain types require more care. +// Subbyte Types: uint2_t, uint4_t, etc +// Requires construction of a subbyte_iterator in order to properly +// resolve each element in byte-addressed memory. +// Sparse Types: sparse_elem +// A type that holds one physical element meant to represent S number of logical elements. +// Requires construction of a sparse_ptr that emulates access to the S logical elements. 
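A minimal usage sketch of the dispatch described above (header paths are assumptions; `raw` stands in for any untyped buffer):

#include <cute/pointer.hpp>                 // recast_ptr, added by this diff
#include <cute/numeric/numeric_types.hpp>   // assumed path for int4b_t

void recast_sketch(void* raw) {
  auto pf = cute::recast_ptr<float>(raw);           // plain float*
  auto p4 = cute::recast_ptr<cute::int4b_t>(raw);   // subbyte_iterator over 4-bit elements
  // Indexing p4 resolves the correct 4-bit element inside byte-addressed memory.
  (void) pf; (void) p4;
}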
+// + +template +CUTE_HOST_DEVICE constexpr +auto +recast_ptr(void* ptr) +{ + if constexpr (is_sparse::value) { + constexpr int sparsity = NewT::sparsity; + NewT* p = reinterpret_cast(ptr); + return make_sparse_ptr(p); + } else + if constexpr (cute::is_subbyte_v) { + return subbyte_iterator(ptr); + } else { + return reinterpret_cast(ptr); + } + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +recast_ptr(void const* ptr) +{ + if constexpr (is_sparse::value) { + constexpr int sparsity = NewT::sparsity; + NewT const* p = reinterpret_cast(ptr); + return make_sparse_ptr(p); + } else + if constexpr (cute::is_subbyte_v) { + return subbyte_iterator(ptr); + } else { + return reinterpret_cast(ptr); + } + CUTE_GCC_UNREACHABLE; +} + +// Disambiguate nullptr +template +CUTE_HOST_DEVICE constexpr +auto +recast_ptr(decltype(nullptr)) { // nullptr_t + return recast_ptr(static_cast(nullptr)); +} + +// +// gmem_ptr +// + +template +struct gmem_ptr : iter_adaptor> { + using iter_adaptor>::iter_adaptor; +}; + +template +struct is_gmem : false_type {}; +template // Found the gmem +struct is_gmem> : true_type {}; +template // Recurse on ::iterator, if possible +struct is_gmem> : is_gmem {}; +template +constexpr bool is_gmem_v = is_gmem
<P>
::value; + +// Idempotent gmem tag on an iterator +template +CUTE_HOST_DEVICE constexpr +auto +make_gmem_ptr(Iterator iter) { + if constexpr (is_gmem::value) { + return iter; + } else { + return gmem_ptr{iter}; + } + CUTE_GCC_UNREACHABLE; +} + +// Explicitly typed construction from a raw pointer +template +CUTE_HOST_DEVICE constexpr +auto +make_gmem_ptr(void* ptr) { + return make_gmem_ptr(recast_ptr(ptr)); +} + +// Explicitly typed construction from a raw pointer +template +CUTE_HOST_DEVICE constexpr +auto +make_gmem_ptr(void const* ptr) { + return make_gmem_ptr(recast_ptr(ptr)); +} + +// nullptr_t overload for make_gmem_ptr(nullptr) disambiguation +template +CUTE_HOST_DEVICE constexpr +auto +make_gmem_ptr(decltype(nullptr)) { // nullptr_t + return make_gmem_ptr(recast_ptr(nullptr)); +} + +// The gmem tag is invariant over type-recast +template +CUTE_HOST_DEVICE constexpr +auto +recast_ptr(gmem_ptr
<P>
const& ptr) { + return make_gmem_ptr(recast_ptr(ptr.get())); +} + +// +// smem_ptr +// + +template +struct smem_ptr : iter_adaptor> { + using iter_adaptor>::iter_adaptor; +}; + +template +struct is_smem : false_type {}; +template // Found the smem +struct is_smem> : true_type {}; +template // Recurse on ::iterator, if possible +struct is_smem> : is_smem {}; +template +constexpr bool is_smem_v = is_smem
<P>
::value; + +// Idempotent smem tag on an iterator +template +CUTE_HOST_DEVICE constexpr +auto +make_smem_ptr(Iterator iter) { + if constexpr (is_smem::value) { + return iter; + } else { + return smem_ptr{iter}; + } + CUTE_GCC_UNREACHABLE; +} + +// Make a smem swizzle pointer, common operation +template +CUTE_HOST_DEVICE constexpr +auto +make_smem_ptr(Iterator ptr, Swizzle sw) +{ + return make_swizzle_ptr(make_smem_ptr(ptr), sw); +} + +// Explicitly typed construction from a raw pointer +template +CUTE_HOST_DEVICE constexpr +auto +make_smem_ptr(void* ptr) { + return make_smem_ptr(recast_ptr(ptr)); +} + +// Explicitly typed construction from a raw pointer +template +CUTE_HOST_DEVICE constexpr +auto +make_smem_ptr(void const* ptr) { + return make_smem_ptr(recast_ptr(ptr)); +} + +// nullptr_t overload for make_smem_ptr(nullptr) disambiguation +template +CUTE_HOST_DEVICE constexpr +auto +make_smem_ptr(decltype(nullptr)) { // nullptr_t + return make_smem_ptr(recast_ptr(nullptr)); +} + +// The smem tag is invariant over type-recast +template +CUTE_HOST_DEVICE constexpr +auto +recast_ptr(smem_ptr
<P>
const& ptr) { + return make_smem_ptr(recast_ptr(ptr.get())); +} + +// +// rmem_ptr +// + +template +struct rmem_ptr : iter_adaptor> { + using iter_adaptor>::iter_adaptor; +}; + +// Anything that is not gmem or smem is rmem +template +struct is_rmem : bool_constant::value || is_smem::value)> {}; +template +struct is_rmem> : true_type {}; +template +constexpr bool is_rmem_v = is_rmem
<P>
::value; + +// Idempotent rmem tag on an iterator +template +CUTE_HOST_DEVICE constexpr +auto +make_rmem_ptr(Iterator iter) { + if constexpr (is_rmem::value) { + return iter; + } else { + return rmem_ptr{iter}; + } + CUTE_GCC_UNREACHABLE; +} + +// Explicitly typed construction from a raw pointer +template +CUTE_HOST_DEVICE constexpr +auto +make_rmem_ptr(void* ptr) { + return make_rmem_ptr(recast_ptr(ptr)); +} + +// Explicitly typed construction from a raw pointer +template +CUTE_HOST_DEVICE constexpr +auto +make_rmem_ptr(void const* ptr) { + return make_rmem_ptr(recast_ptr(ptr)); +} + +// The rmem tag is invariant over type-recast +template +CUTE_HOST_DEVICE constexpr +auto +recast_ptr(rmem_ptr
<P>
const& ptr) { + return make_rmem_ptr(recast_ptr(ptr.get())); +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(gmem_ptr ptr) +{ + printf("gmem_"); print(ptr.get()); +} + +template +CUTE_HOST_DEVICE void print(smem_ptr ptr) +{ + printf("smem_"); print(ptr.get()); +} + +template +CUTE_HOST_DEVICE void print(rmem_ptr ptr) +{ + printf("rmem_"); print(ptr.get()); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, gmem_ptr ptr) +{ + return os << "gmem_[" << int(sizeof_bits>::value) << "b]"; +} + +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, smem_ptr ptr) +{ + return os << "smem_[" << int(sizeof_bits>::value) << "b]"; +} + +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, rmem_ptr ptr) +{ + return os << "rmem_[" << int(sizeof_bits>::value) << "b]"; +} + +#endif // !defined(__CUDACC_RTC__) + +} // end namespace cute diff --git a/include/cute/pointer_base.hpp b/include/cute/pointer_base.hpp new file mode 100644 index 0000000000..57ad0b3cde --- /dev/null +++ b/include/cute/pointer_base.hpp @@ -0,0 +1,262 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::sizeof_bits +#include // Int<0> +#include // cute::declval, cute::void_t, etc + +namespace cute +{ + +// +// C++20 iterator_traits +// + +namespace detail { +// Default reference type of an iterator +template +struct iter_ref { using type = decltype(*declval()); }; +// Prefer to propagate ::reference +template +struct iter_ref> { using type = typename T::reference; }; +} // end namespace detail + +template +using iter_reference = detail::iter_ref; +template +using iter_reference_t = typename iter_reference::type; + +namespace detail { +// Default element_type of an iterator +template +struct iter_e { using type = remove_reference_t::type>; }; +// Prefer to propagate ::element_type +template +struct iter_e> { using type = typename T::element_type; }; +} // end namespace detail + +template +using iter_element = detail::iter_e; +template +using iter_element_t = typename iter_element::type; + +namespace detail { +// Default value_type of an iterator +template +struct iter_v { using type = remove_cv_t::type>; }; +// Prefer to propagate ::value_type +template +struct iter_v> { using type = typename T::value_type; }; +} // end namespace detail + +template +using iter_value = detail::iter_v; +template +using iter_value_t = typename iter_value::type; + +template +struct iterator_traits { + using reference = iter_reference_t; + using element_type = iter_element_t; + using value_type = iter_value_t; +}; + +// +// has_dereference to determine if a type is an iterator concept +// + +namespace detail { +template +struct has_dereference : CUTE_STL_NAMESPACE::false_type {}; +template +struct has_dereference())>> : CUTE_STL_NAMESPACE::true_type {}; +} // end namespace detail + +template +using has_dereference = detail::has_dereference; + +// +// raw_pointer_cast +// + +template +CUTE_HOST_DEVICE constexpr +T* +raw_pointer_cast(T* ptr) { + return ptr; +} + +// The statically-known alignment of a dynamic pointer is unknown +template +CUTE_HOST_DEVICE constexpr +Int<0> +max_alignment(T*) { + return {}; +} + +// +// A very simplified iterator adaptor. +// Derived classed may override methods, but be careful to reproduce interfaces exactly. +// Clients should never have an instance of this class. Do not write methods that take this as a param. 
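A sketch of the intended usage pattern for the adaptor defined just below; it mirrors how gmem_ptr/smem_ptr above wrap it, and the name my_tag_ptr is hypothetical:

// Hypothetical tag pointer in the same style as gmem_ptr/smem_ptr:
template <class Iterator>
struct my_tag_ptr : cute::iter_adaptor<Iterator, my_tag_ptr<Iterator>> {
  using cute::iter_adaptor<Iterator, my_tag_ptr<Iterator>>::iter_adaptor;
};

// my_tag_ptr<float*> p{buf};   // wraps an existing iterator/pointer
// float x = p[3];              // forwards to buf[3]
// auto  q = p + 4;             // operator+ returns another my_tag_ptr<float*>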
+// + +template +struct iter_adaptor +{ + using iterator = Iterator; + using reference = typename iterator_traits::reference; + using element_type = typename iterator_traits::element_type; + using value_type = typename iterator_traits::value_type; + + iterator ptr_; + + CUTE_HOST_DEVICE constexpr + iter_adaptor(iterator ptr = {}) : ptr_(ptr) {} + + CUTE_HOST_DEVICE constexpr + reference operator*() const { return *ptr_; } + + template + CUTE_HOST_DEVICE constexpr + reference operator[](Index const& i) const { return ptr_[i]; } + + template + CUTE_HOST_DEVICE constexpr + DerivedType operator+(Index const& i) const { return {ptr_ + i}; } + + CUTE_HOST_DEVICE constexpr + iterator get() const { return ptr_; } + + CUTE_HOST_DEVICE constexpr + friend bool operator==(DerivedType const& x, DerivedType const& y) { return x.ptr_ == y.ptr_; } + CUTE_HOST_DEVICE constexpr + friend bool operator!=(DerivedType const& x, DerivedType const& y) { return x.ptr_ != y.ptr_; } + CUTE_HOST_DEVICE constexpr + friend bool operator< (DerivedType const& x, DerivedType const& y) { return x.ptr_ < y.ptr_; } + CUTE_HOST_DEVICE constexpr + friend bool operator<=(DerivedType const& x, DerivedType const& y) { return x.ptr_ <= y.ptr_; } + CUTE_HOST_DEVICE constexpr + friend bool operator> (DerivedType const& x, DerivedType const& y) { return x.ptr_ > y.ptr_; } + CUTE_HOST_DEVICE constexpr + friend bool operator>=(DerivedType const& x, DerivedType const& y) { return x.ptr_ >= y.ptr_; } +}; + +template +CUTE_HOST_DEVICE constexpr +auto +raw_pointer_cast(iter_adaptor const& x) { + return raw_pointer_cast(x.ptr_); +} + +template +CUTE_HOST_DEVICE constexpr +auto +max_alignment(iter_adaptor const& x) { + return max_alignment(x.ptr_); +} + +// +// counting iterator -- quick and dirty +// + +template +struct counting_iterator +{ + using index_type = T; + using value_type = T; + using reference = T; + + index_type n_; + + CUTE_HOST_DEVICE constexpr + counting_iterator(index_type n = 0) : n_(n) {} + + CUTE_HOST_DEVICE constexpr + index_type operator*() const { return n_; } + + CUTE_HOST_DEVICE constexpr + index_type operator[](index_type i) const { return n_ + i; } + + CUTE_HOST_DEVICE constexpr + counting_iterator operator+(index_type i) const { return {n_ + i}; } + CUTE_HOST_DEVICE constexpr + counting_iterator& operator++() { ++n_; return *this; } + CUTE_HOST_DEVICE constexpr + counting_iterator operator++(int) { counting_iterator ret = *this; ++n_; return ret; } + + CUTE_HOST_DEVICE constexpr + friend bool operator==(counting_iterator const& x, counting_iterator const& y) { return x.n_ == y.n_; } + CUTE_HOST_DEVICE constexpr + friend bool operator!=(counting_iterator const& x, counting_iterator const& y) { return x.n_ != y.n_; } + CUTE_HOST_DEVICE constexpr + friend bool operator< (counting_iterator const& x, counting_iterator const& y) { return x.n_ < y.n_; } + CUTE_HOST_DEVICE constexpr + friend bool operator<=(counting_iterator const& x, counting_iterator const& y) { return x.n_ <= y.n_; } + CUTE_HOST_DEVICE constexpr + friend bool operator> (counting_iterator const& x, counting_iterator const& y) { return x.n_ > y.n_; } + CUTE_HOST_DEVICE constexpr + friend bool operator>=(counting_iterator const& x, counting_iterator const& y) { return x.n_ >= y.n_; } +}; + +template +CUTE_HOST_DEVICE constexpr +T +raw_pointer_cast(counting_iterator const& x) { + return x.n_; +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(T const* const ptr) +{ + printf("ptr["); print(sizeof_bits::value); printf("b](%p)", 
ptr); +} + +template +CUTE_HOST_DEVICE void print(counting_iterator ptr) +{ + printf("counting_iter("); print(ptr.n_); printf(")"); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, counting_iterator ptr) +{ + return os << "counting_iter(" << ptr.n_ << ")"; +} +#endif // !defined(__CUDACC_RTC__) + +} // end namespace cute diff --git a/include/cute/pointer_flagged.hpp b/include/cute/pointer_flagged.hpp new file mode 100644 index 0000000000..eb8d7e452e --- /dev/null +++ b/include/cute/pointer_flagged.hpp @@ -0,0 +1,199 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::ComposedLayout +#include // cute::make_smem_ptr +#include // cute::is_sparse +#include // cute::make_swizzle_ptr +#include // cute::cast_smem_ptr_to_uint +#include // cute::Int + +namespace cute +{ + +// +// Stand-in Swizzle Layout +// A model of a nullptr smem_ptr with B == sizeof_bits::value +// That represents an unset pointer. 
This is a placeholder type that is waiting for an smem_ptr +// + +template +struct smem_ptr_flag_bits : Int<0> {}; + +using smem_ptr_flag = smem_ptr_flag_bits<1>; + +// A flagged construction method to transform ComposedLayout +// Make a swizzle pointer tensor and check that the intended type size matches +template +CUTE_HOST_DEVICE constexpr +auto +make_tensor(Iterator const& ptr, + ComposedLayout,Layout> const& layout) +{ + static_assert(is_smem::value, "Expected smem."); + static_assert(B == sizeof_bits>::value, "Expected a B-bit pointer type."); + return make_tensor(make_smem_ptr(ptr.get(), layout.layout_a()), + layout.layout_b()); +} + +// NOTE: To preserve smem_ptr_flag_bits under recast ops +template +CUTE_HOST_DEVICE constexpr +auto +upcast(ComposedLayout,Layout> const& layout) +{ + return composition(layout.layout_a(), smem_ptr_flag_bits{}, upcast(layout.layout_b())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +downcast(ComposedLayout,Layout> const& layout) +{ + return composition(layout.layout_a(), smem_ptr_flag_bits{}, downcast(layout.layout_b())); +} + +// +// Conversion with swizzle_layout +// + +template +CUTE_HOST_DEVICE +auto +as_position_independent_swizzle_layout(ComposedLayout,Layout> const& layout) +{ + return composition(recast_layout>(layout.layout_a()), Int<0>{}, layout.layout_b()); +} + +template +CUTE_HOST_DEVICE +auto +as_position_independent_swizzle_tensor(Tensor&& tensor) +{ + static_assert(is_smem>::value, "Expected smem tensor."); + using SwizzleFn = get_swizzle_t>; + if constexpr (SwizzleFn::num_bits == 0) { + return tensor; + } else { +#if !defined(NDEBUG) + { + uint32_t address = cast_smem_ptr_to_uint(raw_pointer_cast(static_cast(tensor).data())); + uint32_t mask = ((uint32_t(1) << SwizzleFn::num_base) - 1) | SwizzleFn::swizzle_code; + assert((address & mask) == 0); // Alignment to the Base, Z, and Y of Swizzle + } +#endif + using T = typename remove_cvref_t::value_type; + // Recast swizzle from acting on byte-addressed pointers to elements of type-T + auto new_swizzle = recast_layout(SwizzleFn{}); + // Strip off everything and create a new smem_ptr for type-T + auto new_ptr = make_smem_ptr(raw_pointer_cast(static_cast(tensor).data())); + return make_tensor(new_ptr, composition(new_swizzle, Int<0>{}, tensor.layout())); + } + CUTE_GCC_UNREACHABLE; +} + +// A model of a nullptr sparse_ptr> with B == sizeof_bits::value +// That represents an unset pointer. 
This is a placeholder type that is waiting for an smem_ptr +template +struct smem_sparse_ptr_flag_bits : Int<0> {}; + +template +using smem_sparse_ptr_flag = smem_sparse_ptr_flag_bits; + +// A flagged construction method to transform ComposedLayout +// Make a swizzle pointer tensor and check that the intended type size matches +template +CUTE_HOST_DEVICE constexpr +auto +make_tensor(Iterator const& ptr, + ComposedLayout,Layout> const& layout) +{ + static_assert(is_smem::value, "Expected smem."); + static_assert(is_sparse_ptr::value, "Expected sparse iter"); + static_assert(is_sparse>::value, "Expected sparse elem"); + static_assert(S == iter_value_t::sparsity, "Expected sparsity S"); + static_assert(B == sizeof_bits::raw_type>::value, "Expected B-bit pointer type"); + return make_tensor(make_swizzle_ptr(ptr, layout.layout_a()), layout.layout_b()); +} + +// NOTE: To preserve smem_ptr_flag_bits under recast ops +template +CUTE_HOST_DEVICE constexpr +auto +upcast(ComposedLayout,Layout> const& layout) +{ + static_assert(dependent_false, "Not implemented for safety"); +} + +template +CUTE_HOST_DEVICE constexpr +auto +downcast(ComposedLayout,Layout> const& layout) +{ + static_assert(dependent_false, "Not implemented for safety"); +} + +// +// Display utilities +// + +// Capture and cast smem_ptr_flag Layouts to offset-0 layouts +template +CUTE_HOST_DEVICE +void +print_layout(ComposedLayout,Layout> const& layout) +{ + print_layout(as_position_independent_swizzle_layout(layout)); +} + +template +CUTE_HOST_DEVICE +void +print_latex(ComposedLayout,Layout> const& layout) +{ + print_latex(as_position_independent_swizzle_layout(layout)); +} + +template +CUTE_HOST_DEVICE void print(smem_ptr_flag_bits ptr) +{ + printf("smem_ptr[%db](unset)", B); +} + +template +CUTE_HOST_DEVICE void print(smem_sparse_ptr_flag_bits) +{ + printf("smem_sparse<%d>_ptr[%db](unset)", S, B); +} + +} // end namespace cute diff --git a/include/cute/pointer_sparse.hpp b/include/cute/pointer_sparse.hpp new file mode 100644 index 0000000000..ccae458650 --- /dev/null +++ b/include/cute/pointer_sparse.hpp @@ -0,0 +1,172 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::iter_adaptor +#include // cute::false_type, cute::true_type +#include // cute::ratio + +namespace cute +{ + +// A data type that holds one physical element meant to represent Sparsity number of logical elements +// This class is purposely not compatible with anything -- know what you're doing if you attempt to use it +template +struct sparse_elem +{ + static constexpr int sparsity = Sparsity; + using raw_type = T; + T elem_; + + CUTE_HOST_DEVICE constexpr + explicit sparse_elem(T const& elem = {}) : elem_(elem) {} + + CUTE_HOST_DEVICE constexpr friend bool operator==(sparse_elem const& a, sparse_elem const& b) { return a.elem_ == b.elem_; } + CUTE_HOST_DEVICE constexpr friend bool operator!=(sparse_elem const& a, sparse_elem const& b) { return a.elem_ != b.elem_; } + CUTE_HOST_DEVICE constexpr friend bool operator< (sparse_elem const& a, sparse_elem const& b) { return a.elem_ < b.elem_; } + CUTE_HOST_DEVICE constexpr friend bool operator<=(sparse_elem const& a, sparse_elem const& b) { return a.elem_ <= b.elem_; } + CUTE_HOST_DEVICE constexpr friend bool operator> (sparse_elem const& a, sparse_elem const& b) { return a.elem_ > b.elem_; } + CUTE_HOST_DEVICE constexpr friend bool operator>=(sparse_elem const& a, sparse_elem const& b) { return a.elem_ >= b.elem_; } +}; + +template +struct is_sparse : false_type {}; +template +struct is_sparse : is_sparse {}; +template +struct is_sparse> : true_type {}; +template +static constexpr auto is_sparse_v = is_sparse::value; + +// Overload sizeof_bits for sparse_elem. +// Much like subbyte element types, this is the effective number of bits in a sparse_elem +// rather than actual physical bits that may be used in storing one. Also like subbyte element +// types, modified iterators are required to properly index and access sparse_elems. +// +// Defining sizeof_bits like this makes reasonable expressions like N * sizeof_bits_v meaningful +// even when E is subbyte or sparse. However, this also means that sparse_elem can rather easily be +// confused with subbyte elements and special care should be taken with each. +template +struct sizeof_bits> { + // Simple implementation that conforms to sizeof_bits + //static constexpr auto value = sizeof_bits::value / S; + //static_assert(value != 0, "sizeof_bits=0 detected. Sparsity is larger than width."); + //static_assert((sizeof_bits::value % S) == 0, "Width needs to be a multiple of sparsity.") + + // Interesting experiment that allows any sparsity level to be used by potentially presenting + // an integral_ratio rather than size_t. This is valid in most integer expressions as well. 
+ static constexpr auto value = cute::ratio(cute::Int>{}, cute::Int{}); +}; + +// +// sparse_ptr +// + +template +struct is_sparse_ptr : false_type {}; +template +struct is_sparse_ptr> : is_sparse_ptr {}; + +template +struct sparse_ptr : iter_adaptor> +{ + using reference = typename iterator_traits::reference; + using element_type = typename iterator_traits::element_type; + using value_type = typename iterator_traits::value_type; + + // Sanity, for now + static_assert(is_sparse::value, "Enforce sparse value-type"); + static_assert(Sparsity == iter_value_t::sparsity, "Enforce sparsity S"); + static_assert(not is_sparse_ptr::value, "Enforce sparse singleton"); + + template + CUTE_HOST_DEVICE constexpr + sparse_ptr operator+(Index const& i) const { + // Only allow offset by multiples of the sparsity factor, + // else the misalignments become a bug. E.g. (sparse_ptr<8,I>{} + 7) + 7 + // Motivation for subsparse_iterator or generalization of subbyte_iterator? + assert(i % Sparsity == 0); + return {this->get() + i / Sparsity}; + } + + template + CUTE_HOST_DEVICE constexpr + reference operator[](Index const& i) const { + // Allow offset by any value and dereference. + // Not implemented in terms of sparse_ptr::op+() + return *(this->get() + i / Sparsity); + } +}; + +template +struct is_sparse_ptr> : true_type {}; + +template +CUTE_HOST_DEVICE constexpr +auto +make_sparse_ptr(Iter const& iter) { + if constexpr (Sparsity == 1) { + return iter; + } else { + return sparse_ptr{iter}; + } + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +recast_ptr(sparse_ptr const& ptr) { + static_assert(not is_sparse::value); + return recast_ptr(ptr.get()); +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(sparse_ptr ptr) +{ + printf("sparse<%d>_", S); print(ptr.get()); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, sparse_ptr ptr) +{ + return os << "sparse<" << S << ">_" << ptr.get(); +} +#endif + +} // end namespace cute diff --git a/include/cute/pointer_swizzle.hpp b/include/cute/pointer_swizzle.hpp new file mode 100644 index 0000000000..1a802cfdc6 --- /dev/null +++ b/include/cute/pointer_swizzle.hpp @@ -0,0 +1,176 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::iter_adaptor +#include // cute::Swizzle, cute::get_swizzle primary template +#include // cute::iterator_traits +#include // cute::subbyte_iterator + +/* This implements a swizzle pointer of the form + * InvolutionFn o PtrAdd + * where the InvolutionFn need not be linear. + * + * This differs subtly from swizzle_layout because the smem pointer is used + * as the offset. That means that swizzle_layout will implement position-independent + * swizzle layouts, while swizzle_ptr implements position-dependent swizzle tensors. + * Arch chose to design hardware with position-dependent swizzles. + * + * For clarity: + * NormalLayout : DeRef <- PtrAdd <- [Layout] + * ComposedLayout: DeRef <- PtrAdd <- [Swizzle <- OffsetAdd <- Layout] + * SwizzlePtr : [DeRef <- Swizzle <- PtrAdd] <- Layout + * + * Furthermore, for known swizzles, this pointer attempts to decay itself + * to a normal-pointer with a new layout containing dynamic or static strides. + * This is possible by determining the subdomain of the InvolutionFn + * that is identity and testing if the Layout's codomain is contained + * within it. + */ + +namespace cute +{ + +// concept SwizzleFn { +// CUTE_HOST_DEVICE constexpr static uint apply(uint); +// } +// See Swizzle in swizzle.hpp for common swizzle-functions. 
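A usage sketch of the position-dependent swizzle pointer described above (the concrete Swizzle<3,4,3> is just an example pattern; header paths are assumptions):

#include <cute/pointer.hpp>   // make_smem_ptr, added by this diff
#include <cute/swizzle.hpp>   // Swizzle, added by this diff

void swizzle_ptr_sketch(float* buf) {
  auto p  = cute::make_smem_ptr(buf);                          // plain smem_ptr
  auto sp = cute::make_smem_ptr(buf, cute::Swizzle<3,4,3>{});  // swizzle_ptr over the smem_ptr
  // sp[i] applies Swizzle<3,4,3>::apply to the address first and then dereferences,
  // i.e. the swizzle acts on the pointer value itself (position-dependent).
  (void) p; (void) sp;
}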
+ +template +struct swizzle_ptr : iter_adaptor> +{ + using iterator = Iterator; + using reference = typename iterator_traits::reference; + using element_type = typename iterator_traits::element_type; + using value_type = typename iterator_traits::value_type; + + using iter_adaptor>::iter_adaptor; + + template + CUTE_HOST_DEVICE constexpr static + Iter apply_swizzle(Iter ptr) { + return {apply_swizzle(ptr.get())}; + } + + template + CUTE_HOST_DEVICE constexpr static + T* apply_swizzle(T* ptr) { + return reinterpret_cast(SwizzleFn::apply(reinterpret_cast(ptr))); + } + + template + CUTE_HOST_DEVICE constexpr static + subbyte_iterator apply_swizzle(subbyte_iterator ptr) { + return {apply_swizzle(ptr.ptr_), ptr.idx_}; + } + + CUTE_HOST_DEVICE constexpr + reference operator*() const { + return *apply_swizzle(this->get()); + } + + template + CUTE_HOST_DEVICE constexpr + reference operator[](Int const& i) const { + return *apply_swizzle(this->get() + i); + } +}; + +// +// Helper Function +// +template // Found the SwizzleFn +struct get_swizzle> { using type = SwizzleFn; }; +template // Recurse into anything with a ::iterator +struct get_swizzle> : get_swizzle {}; + +template +CUTE_HOST_DEVICE constexpr +swizzle_ptr +make_swizzle_ptr(Iterator ptr, SwizzleFn) { + return {ptr}; +} + +// Swizzle-0 specialization for immediate decay +template +CUTE_HOST_DEVICE constexpr +Iterator +make_swizzle_ptr(Iterator ptr, Swizzle<0,M,S>) { + return ptr; +} + +// +// Recast +// + +template +CUTE_HOST_DEVICE constexpr +auto +raw_pointer_cast(swizzle_ptr const& ptr) { + return raw_pointer_cast(ptr.get()); +} + +// SwizzleFn operates on the pointer address, so it doesn't care about the type +template +CUTE_HOST_DEVICE constexpr +auto +recast_ptr(swizzle_ptr const& ptr) { + return make_swizzle_ptr(recast_ptr(ptr.get()), SwizzleFn{}); +} + +// The statically-known alignment of a swizzle pointer is the alignment of the swizzle function converted to bits +template +CUTE_HOST_DEVICE constexpr +auto +max_alignment(swizzle_ptr const&) { + return Int<8>{} * max_alignment(SwizzleFn{}); +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(swizzle_ptr ptr) +{ + print(SwizzleFn{}); printf("_"); print(ptr.get()); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, swizzle_ptr ptr) +{ + return os << SwizzleFn{} << "_" << ptr.get(); +} +#endif + +} // end namespace cute diff --git a/include/cute/stride.hpp b/include/cute/stride.hpp new file mode 100644 index 0000000000..f2d31f4e34 --- /dev/null +++ b/include/cute/stride.hpp @@ -0,0 +1,598 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::__CUTE_REQUIRES +#include // cute::is_tuple +#include // cute::is_integral +#include // cute::seq +#include // cute::divmod +#include // cute::basis_get +#include // cute::identity +#include // cute::fold +#include // cute::is_congruent + +namespace cute +{ + +/** crd2idx(c,s,d) maps a coordinate within to an index + * + * This is computed as follows: + * [coord, shape, and stride are all integers => step forward by stride] + * op(c, s, d) => c * d + * [coord is integer, shape and stride are tuple => divmod coord for each mode] + * op(c, (s,S), (d,D)) => op(c % prod(s), s, d) + op(c / prod(s), (S), (D)) + * [coord, shape, and stride are all tuples => consider each mode independently] + * op((c,C), (s,S), (d,D)) => op(c, s, d) + op((C), (S), (D)) + */ +template +CUTE_HOST_DEVICE constexpr +auto +crd2idx(Coord const& coord, + Shape const& shape, + Stride const& stride); + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +crd2idx_ttt(Coord const& coord, + Shape const& shape, + Stride const& stride, seq) +{ + return (... + crd2idx(get(coord), get(shape), get(stride))); +} + +template +CUTE_HOST_DEVICE constexpr +auto +crd2idx_itt(CInt const& coord, + STuple const& shape, + DTuple const& stride, seq) +{ + if constexpr (sizeof...(Is) == 0) { // Avoid recursion and mod on single/last iter + return crd2idx(coord, get(shape), get(stride)); + } else if constexpr (is_constant<0, CInt>::value) { + return crd2idx(_0{}, get(shape), get(stride)) + + (_0{} + ... 
+ crd2idx(_0{}, get(shape), get(stride))); + } else { // General case + auto [div, mod] = divmod(coord, product(get(shape))); + return crd2idx(mod, get(shape), get(stride)) + + crd2idx_itt(div, shape, stride, seq{}); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +crd2idx(Coord const& coord, + Shape const& shape, + Stride const& stride) +{ + if constexpr (is_tuple::value) { + if constexpr (is_tuple::value) { // tuple tuple tuple + static_assert(tuple_size::value == tuple_size< Shape>::value, "Mismatched Ranks"); + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return detail::crd2idx_ttt(coord, shape, stride, tuple_seq{}); + } else { // tuple "int" "int" + static_assert(sizeof(Coord) == 0, "Invalid parameters"); + } + } else { + if constexpr (is_tuple::value) { // "int" tuple tuple + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return detail::crd2idx_itt(coord, shape, stride, tuple_seq{}); + } else { // "int" "int" "int" + return coord * stride; + } + } + + CUTE_GCC_UNREACHABLE; +} + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +crd2idx_horner(CTuple const& coord, + STuple const& shape, seq) +{ + if constexpr (sizeof...(Is) == 0) { // No recursion on single/last iter + return get(coord); + } else { // General case + return get(coord) + get(shape) * crd2idx_horner(coord, shape, seq{}); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +/** crd2idx(c,s) maps a coordinate within Shape to an index + * via a colexicographical enumeration of coordinates in Shape. + * i = c0 + s0 * (c1 + s1 * (c2 + s2 * ...)) + */ +template +CUTE_HOST_DEVICE constexpr +auto +crd2idx(Coord const& coord, + Shape const& shape) +{ + if constexpr (is_integral::value) { // Coord is already an index + return coord; + } else if constexpr (is_integral::value) { + static_assert(dependent_false, "Invalid parameters"); + } else { // Make congruent, flatten, and apply Horner's method + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + auto flat_coord = flatten(coord); + auto flat_shape = flatten(product_like(shape, coord)); + return detail::crd2idx_horner(flat_coord, flat_shape, tuple_seq{}); + } + + CUTE_GCC_UNREACHABLE; +} + +/** idx2crd(i,s,d) splits an index into a coordinate within . + * + * This is computed as follows: + * [index, shape, and stride are all integers => determine 1D coord] + * op(i, s, d) => (i / d) % s + * [index is integer, shape and stride are tuple => determine component for each mode] + * op(i, (s,S), (d,D)) => (op(i, s, d), op(i, S, D)...) + * [index, shape, and stride are all tuples => consider each mode independently] + * op((i,I), (s,S), (d,D)) => (op(i, s, d), op((I), (S), (D))) + * + * NOTE: This only works for compact shape+stride layouts. 
A more general version would + * apply to all surjective layouts + */ +template +CUTE_HOST_DEVICE constexpr +auto +idx2crd(Index const& idx, + Shape const& shape, + Stride const& stride) +{ + if constexpr (is_tuple::value) { + if constexpr (is_tuple::value) { // tuple tuple tuple + static_assert(tuple_size::value == tuple_size< Shape>::value, "Mismatched Ranks"); + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return transform(idx, shape, stride, [](auto const& i, auto const& s, auto const& d){ return idx2crd(i,s,d); }); + } else { // tuple "int" "int" + static_assert(sizeof(Index) == 0, "Invalid parameters"); + } + } else { + if constexpr (is_tuple::value) { + if constexpr (is_tuple::value) { // "int" tuple tuple + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return transform(shape, stride, [&](auto const& s, auto const& d){ return idx2crd(idx,s,d); }); + } else { // "int" tuple "int" + return transform(shape, compact_col_major(shape, stride), [&](auto const& s, auto const& d){ return idx2crd(idx,s,d); }); + } + } else { // "int" "int" "int" + if constexpr (is_constant<1, Shape>::value) { + // Skip potential stride-0 division + return Int<0>{}; + } else { + return (idx / stride) % shape; + } + } + } + + CUTE_GCC_UNREACHABLE; +} + +/** idx2crd(i,s) splits an index into a coordinate within Shape + * via a colexicographical enumeration of coordinates in Shape. + * c0 = (idx / 1) % s0 + * c1 = (idx / s0) % s1 + * c2 = (idx / (s0 * s1)) % s2 + * ... + */ +template +CUTE_HOST_DEVICE constexpr +auto +idx2crd(Index const& idx, + Shape const& shape) +{ + if constexpr (is_tuple::value) { + if constexpr (is_tuple::value) { // tuple tuple + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return transform(idx, shape, [](auto const& i, auto const& s) { return idx2crd(i,s); }); + } else { // tuple "int" + static_assert(sizeof(Index) == 0, "Invalid parameters"); + } + } else { + if constexpr (is_tuple::value) { // "int" tuple + return transform_leaf(as_arithmetic_tuple(crd2idx(idx, shape, make_basis_like(shape))), identity{}); + } else { // "int" "int" + return idx; + } + } + + CUTE_GCC_UNREACHABLE; +} + +// +// crd2crd +// + +template +CUTE_HOST_DEVICE constexpr +auto +crd2crd(Coord const& coord, + SShape const& src_shape, + DShape const& dst_shape) +{ + if constexpr (is_tuple::value && is_tuple::value && is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return transform(coord, src_shape, dst_shape, [](auto const& c, auto const& s, auto const& d) { return crd2crd(c,s,d); }); + } else { + // assert(size(src_shape) == size(dst_shape)) + return idx2crd(crd2idx(coord, src_shape), dst_shape); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Compact Major +// + +// Tags for common layouts and dispatching +struct LayoutLeft; // Col-major layout mapping; leftmost extent has stride 1 +using GenColMajor = LayoutLeft; // Alias + +struct LayoutRight; // Row-major layout mapping; rightmost extent has stride 1 +using GenRowMajor = LayoutRight; // Alias + +namespace detail { + +// For GCC8.5 -- Use of lambdas in unevaluated contexts. Instead use function objects. 
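A worked example of the index/coordinate mappings and compact strides defined in this header (illustrative values only):

// crd2idx((1,2), (3,4), (1,3))  =  1*1 + 2*3  =  7
// idx2crd(7, (3,4), (1,3))      =  ((7/1) % 3, (7/3) % 4)  =  (1,2)
// compact_col_major((3,4))      =  (_1,3)    // LayoutLeft : leftmost mode gets stride 1
// compact_row_major((3,4))      =  (4,_1)    // LayoutRight: rightmost mode gets stride 1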
+template +struct CompactLambda; + +// @pre is_integral +// Return (result, current * product(shape)) to enable recurrence +template +CUTE_HOST_DEVICE constexpr +auto +compact(Shape const& shape, + Current const& current) +{ + if constexpr (is_tuple::value) { // Shape::tuple Current::int + using Lambda = CompactLambda; // Append or Prepend + using Seq = typename Lambda::template seq; // Seq or RSeq + return cute::detail::fold(shape, cute::make_tuple(cute::make_tuple(), current), Lambda{}, Seq{}); + } else { // Shape::int Current::int + if constexpr (is_constant<1, Shape>::value) { + return cute::make_tuple(Int<0>{}, current); // If current is dynamic, this could save a reg + } else { + return cute::make_tuple(current, current * shape); + } + } + + CUTE_GCC_UNREACHABLE; +} + +// For GCC8.5 -- Specialization LayoutLeft +template <> +struct CompactLambda +{ + template + CUTE_HOST_DEVICE constexpr auto + operator()(Init const& init, Shape const& si) { + auto result = detail::compact(si, get<1>(init)); + return cute::make_tuple(append(get<0>(init), get<0>(result)), get<1>(result)); // Append + } + + template + using seq = tuple_seq; // Seq +}; + +// For GCC8.5 -- Specialization LayoutRight +template <> +struct CompactLambda +{ + template + CUTE_HOST_DEVICE constexpr auto + operator()(Init const& init, Shape const& si) { + auto result = detail::compact(si, get<1>(init)); + return cute::make_tuple(prepend(get<0>(init), get<0>(result)), get<1>(result)); // Prepend + } + + template + using seq = tuple_rseq; // RSeq +}; + +} // end namespace detail + +template , + __CUTE_REQUIRES(is_tuple::value || is_integral::value)> +CUTE_HOST_DEVICE constexpr +auto +compact_major(Shape const& shape, + Current const& current = {}) +{ + if constexpr (is_tuple::value) { // Shape::tuple Current::tuple + static_assert(is_tuple::value, "Invalid parameters"); + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + // Recurse to apply to the terminals of current + return transform(shape, current, [&](auto const& s, auto const& c){ return compact_major(s,c); }); + } else { + return get<0>(detail::compact(shape, current)); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Compact Col Major +// + +struct LayoutLeft { + template + using Apply = decltype(compact_major(declval())); +}; + +template > +CUTE_HOST_DEVICE constexpr +auto +compact_col_major(Shape const& shape, + Current const& current = {}) +{ + return compact_major(shape, current); +} + +// +// Compact Row Major +// + +struct LayoutRight { + template + using Apply = decltype(compact_major(declval())); +}; + +template > +CUTE_HOST_DEVICE constexpr +auto +compact_row_major(Shape const& shape, + Current const& current = {}) +{ + return compact_major(shape, current); +} + +// +// Compact Order -- compute a compact stride based on an ordering of the modes +// + +namespace detail { + +// @pre weakly_congruent(order, shape) +// @pre is_congruent +// @pre is_static +// @pre is_static +template +CUTE_HOST_DEVICE constexpr +auto +compact_order(Shape const& shape, Order const& order, + RefShape const& ref_shape, RefOrder const& ref_order) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Need equal rank of shape and order"); + return transform(shape, order, [&](auto const& s, auto const& o) { return compact_order(s, o, ref_shape, ref_order); }); + } else { + // Compute the starting stride for this shape by accumulating all shapes corresponding to lesser orders + auto stride_start = product(transform(ref_shape, 
ref_order, + [&](auto const& s, auto const& o) { + return conditional_return(o < order, s, Int<1>{}); + })); + return compact_col_major(shape, stride_start); + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +compact_order(Shape const& shape, Order const& order) +{ + auto ref_shape = flatten_to_tuple(product_like(shape, order)); + + auto flat_order = flatten_to_tuple(order); + // Find the largest static element of order + auto max_order = cute::fold(flat_order, Int<0>{}, [](auto v, auto order) { + if constexpr (is_constant::value) { + return order; + } else { + return v; + } + + CUTE_GCC_UNREACHABLE; + }); + // Replace any dynamic elements within order with large-static elements + auto max_seq = make_range{}; + auto ref_order = cute::transform(max_seq, flat_order, [](auto seq_v, auto order) { + if constexpr (is_static::value) { + return order; + } else { + return seq_v; + } + + CUTE_GCC_UNREACHABLE; + }); + + auto new_order = unflatten(ref_order, order); + + return detail::compact_order(shape, new_order, ref_shape, ref_order); +} + +template +CUTE_HOST_DEVICE constexpr +auto +compact_order(Shape const& shape, GenColMajor const& major) +{ + return compact_major(shape); +} + +template +CUTE_HOST_DEVICE constexpr +auto +compact_order(Shape const& shape, GenRowMajor const& major) +{ + return compact_major(shape); +} + +// +// Coordinate iterator +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +void +increment(Coord& coord, Shape const& shape, Order const& order) +{ + ++basis_get(get<0>(order), coord); + cute::for_each(make_range<1, tuple_size::value>{}, [&](auto i){ + if (basis_get(get(order), coord) == basis_get(get(order), shape)) { + basis_get(get(order), coord) = 0; + ++basis_get(get(order), coord); + } + }); +} + +/** Increment a (dynamic) coord colexicographically within a shape + * @pre is_congruent::value + * \code + * auto shape = make_shape(1,2,make_shape(2,3),3); + * auto coord = repeat_like(shape, 0); + * + * for (int i = 0; i < size(shape); ++i) { + * std::cout << i << ": " << coord << std::endl; + * increment(coord, shape); + * } + * \endcode + */ +template +CUTE_HOST_DEVICE constexpr +void +increment(Coord& coord, Shape const& shape) +{ + increment(coord, shape, flatten_to_tuple(make_basis_like(shape))); +} + +} // end namespace detail + +struct ForwardCoordIteratorSentinel +{}; + +// A forward iterator for a starting coordinate in a shape's domain, and a shape. +// The starting coordinate may be zero but need not necessarily be. 
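A small sketch of iterating a shape's coordinates with the iterator defined just below (the umbrella include <cute/tensor.hpp> is assumed):

#include <cstdio>
#include <cute/tensor.hpp>   // assumed umbrella include

void coord_iter_sketch() {
  auto shape = cute::make_shape(2, 3);
  auto it    = cute::make_coord_iterator(shape);    // starts at (0,0), increments colex
  for (int i = 0; i < cute::size(shape); ++i, ++it) {
    cute::print(*it); printf("\n");                 // (0,0) (1,0) (0,1) (1,1) (0,2) (1,2)
  }
}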
+template +struct ForwardCoordIterator +{ + static_assert(is_congruent::value); + + CUTE_HOST_DEVICE constexpr + Coord const& operator*() const { return coord; } + CUTE_HOST_DEVICE constexpr + ForwardCoordIterator& operator++() { detail::increment(coord, shape, Order{}); return *this; } + // Sentinel for the end of the implied range + CUTE_HOST_DEVICE constexpr + bool operator==(ForwardCoordIteratorSentinel const&) const { return basis_get(back(Order{}), coord) == basis_get(back(Order{}), shape); } + CUTE_HOST_DEVICE constexpr + bool operator!=(ForwardCoordIteratorSentinel const&) const { return basis_get(back(Order{}), coord) != basis_get(back(Order{}), shape); } + // NOTE: These are expensive, avoid use + CUTE_HOST_DEVICE constexpr + bool operator==(ForwardCoordIterator const& other) const { return coord == other.coord; } + CUTE_HOST_DEVICE constexpr + bool operator!=(ForwardCoordIterator const& other) const { return coord != other.coord; } + + Coord coord; + Shape const& shape; +}; + +// A forward iterator for a coordinate that starts from a provided coordinate and increments in a prescribed order +template +CUTE_HOST_DEVICE constexpr +auto +make_coord_iterator(Coord const& coord, Shape const& shape) +{ + static_assert(is_congruent::value); + static_assert(is_congruent::value); + static_assert(is_congruent::value); + auto flat_order = flatten_to_tuple(Order{}); + auto inv_order = transform(make_seq{}, [&](auto i){ return find(flat_order, i); }); + auto basis_order = transform_leaf(inv_order, [&](auto i) { return get(flatten_to_tuple(make_basis_like(shape))); }); + return ForwardCoordIterator{coord,shape}; +} + +// A forward iterator for a coordinate that starts from a provided coordinate and increments colex +template +CUTE_HOST_DEVICE constexpr +auto +make_coord_iterator(Coord const& coord, Shape const& shape) +{ + static_assert(is_congruent::value); + auto basis_order = flatten_to_tuple(make_basis_like(shape)); + return ForwardCoordIterator{coord,shape}; +} + +// A forward iterator for a coordinate that starts from zero and increments in a prescribed order +template +CUTE_HOST_DEVICE constexpr +auto +make_coord_iterator(Shape const& shape) +{ + return make_coord_iterator(repeat_like(shape, int(0)), shape); +} + +// A forward iterator for a coordinate that starts from zero and increments colex +template +CUTE_HOST_DEVICE constexpr +auto +make_coord_iterator(Shape const& shape) +{ + return make_coord_iterator(repeat_like(shape, int(0)), shape); +} + +} // end namespace cute diff --git a/include/cute/swizzle.hpp b/include/cute/swizzle.hpp new file mode 100644 index 0000000000..52abf856dd --- /dev/null +++ b/include/cute/swizzle.hpp @@ -0,0 +1,498 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::is_tuple +#include // cute::constant +#include // cute::max, cute::min +#include // cute::transform_apply + +namespace cute +{ + +// A generic Swizzle functor +/* 0bxxxxxxxxxxxxxxxYYYxxxxxxxZZZxxxx + * ^--^ MBase is the number of least-sig bits to keep constant + * ^-^ ^-^ BBits is the number of bits in the mask + * ^---------^ SShift is the distance to shift the YYY mask + * (pos shifts YYY to the right, neg shifts YYY to the left) + * + * e.g. Given + * 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxZZxxx + * the result is + * 0bxxxxxxxxxxxxxxxxYYxxxxxxxxxAAxxx where AA = ZZ xor YY + */ +template +struct Swizzle +{ + static constexpr int num_bits = BBits; + static constexpr int num_base = MBase; + static constexpr int num_shft = SShift; + + static_assert(num_base >= 0, "MBase must be positive."); + static_assert(num_bits >= 0, "BBits must be positive."); + static_assert(abs(num_shft) >= num_bits, "abs(SShift) must be more than BBits."); + + // using 'int' type here to avoid unintentially casting to unsigned... unsure. 
+ using bit_msk = cute::constant; + using yyy_msk = cute::constant; + using zzz_msk = cute::constant; + using msk_sft = cute::constant; + + static constexpr uint32_t swizzle_code = uint32_t(yyy_msk{} | zzz_msk{}); + + template + CUTE_HOST_DEVICE constexpr static + auto + apply(Offset const& offset) + { + return offset ^ shiftr(offset & yyy_msk{}, msk_sft{}); // ZZZ ^= YYY + } + + template + CUTE_HOST_DEVICE constexpr + auto + operator()(Offset const& offset) const + { + return apply(offset); + } + + template + CUTE_HOST_DEVICE constexpr + auto + operator==(Swizzle const&) const + { + return B == BBits && M == MBase && S == SShift; + } +}; + +// +// make_swizzle<0b1000, 0b0100>() -> Swizzle<1,2,1> +// make_swizzle<0b11000000, 0b00000110>() -> Swizzle<2,1,5> +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_swizzle() +{ + constexpr uint32_t BZ = popcount(Y); // Number of swizzle bits + constexpr uint32_t BY = popcount(Z); // Number of swizzle bits + static_assert(BZ == BY, "Number of bits in Y and Z don't match"); + constexpr uint32_t TZ_Y = countr_zero(Y); // Number of trailing zeros in Y + constexpr uint32_t TZ_Z = countr_zero(Z); // Number of trailing zeros in Z + constexpr uint32_t M = cute::min(TZ_Y, TZ_Z) % 32; + constexpr int32_t S = int32_t(TZ_Y) - int32_t(TZ_Z); // Difference in trailing zeros + static_assert((Y | Z) == Swizzle::swizzle_code, "Something went wrong."); + return Swizzle{}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +composition(Swizzle, Swizzle) +{ + static_assert(S0 == S1, "Can only merge swizzles of the same shift."); + constexpr uint32_t Y = Swizzle::yyy_msk::value ^ Swizzle::yyy_msk::value; + constexpr uint32_t Z = Swizzle::zzz_msk::value ^ Swizzle::zzz_msk::value; + return make_swizzle(); + + //return ComposedFn, Swizzle>{}; +} + +// +// Utility for slicing and swizzle "offsets" +// + +// For swizzle functions, it is often needed to keep track of which bits are +// consumed and which bits are free. Furthermore, it is useful to know whether +// each of these bits is known statically or dynamically. + +// MixedBits is an 32-bit unsigned integer class where some bits are known statically +// and some bits are known dynamically. These sets of bits are disjoint and it is +// known statically which bits are known dynamically. + +// MixedBits can only be manipulated through bitwise operations + +// Abstract value: StaticInt | (dynamic_int_ & StaticFlags) +template // 0: static, 1: dynamic +struct MixedBits +{ + // Representation invariants + static_assert(StaticFlags != 0, "Should be at least one dynamic bit in MixedBits."); + static_assert((StaticInt & StaticFlags) == 0, "No static/dynamic overlap allowed in MixedBits."); + + uint32_t dynamic_int_; + // assert((dynamic_int_ & ~StaticFlags) == 0); + + CUTE_HOST_DEVICE constexpr operator uint32_t() const noexcept { return StaticInt | dynamic_int_; } +}; + +// Return a value representing (C{} | (d & C)) potentially using MixedBits to track s and f. +// This maker does allow ((s & f) != 0) and enforces the MixedBits invariant before creation. 
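+// For example (illustrative):
+//   make_mixed_bits(C<0b0100>{}, d, C<0b0011>{})
+// yields MixedBits<0b0100,0b0011>{uint32_t(d) & 0b0011}: bit 2 is statically 1,
+// bits [1:0] come from d at runtime, and every other bit is statically 0.
+// If d is itself a static integer, the result collapses to a plain static C<> instead.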
+template +CUTE_HOST_DEVICE constexpr +auto +make_mixed_bits(C, DynamicType const& d, C) +{ + static_assert(is_integral::value); + constexpr uint32_t new_f = uint32_t(f) & ~uint32_t(s); // StaticBits take precedence, M<0,f>{d} | C{} + if constexpr (new_f == 0 || is_static::value) { + return C{} | (d & C{}); // Just return a static int + } else { + return MixedBits{uint32_t(d) & new_f}; // MixedBits + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Operators +// + +// Equality +template +CUTE_HOST_DEVICE constexpr +auto +operator==(MixedBits const& m, C) +{ + return (S0 == (uint32_t(S1) & ~F0)) && (m.dynamic_int_ == (uint32_t(S1) & F0)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator==(C s, MixedBits const& m) +{ + return m == s; +} + +// Bitwise AND +template +CUTE_HOST_DEVICE constexpr +auto +operator&(MixedBits const& m0, MixedBits const& m1) +{ + // Truth table for (S0,D0,F0) & (S1,D1,F1) -> (S,D,F) + // S0D0F0 | 0X0 | 001 | 011 | 1X0 | + // S1D1F1 + // 0X0 | 0X0 | 0X0 | 0X0 | 0X0 | + // 001 | 0X0 | 001 | 001 | 001 | + // 011 | 0X0 | 001 | 011 | 011 | + // 1X0 | 0X0 | 001 | 011 | 1X0 | + + return make_mixed_bits(C{}, + //(S0 | m0.dynamic_int_) & (S1 | m1.dynamic_int_), + ((S1 & F0) & m0.dynamic_int_) | ((S0 & F1) & m1.dynamic_int_) | (m0.dynamic_int_ & m1.dynamic_int_), + C<(S1 & F0) | (S0 & F1) | (F0 & F1)>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator&(MixedBits const& m, C) +{ + return make_mixed_bits(C{}, + m.dynamic_int_, + C{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator&(C s, MixedBits const& m) +{ + return m & s; +} + +// Bitwise OR +template +CUTE_HOST_DEVICE constexpr +auto +operator|(MixedBits const& m0, MixedBits const& m1) +{ + // Truth table for (S0,D0,F0) | (S1,D1,F1) -> (S,D,F) + // S0D0F0 | 0X0 | 001 | 011 | 1X0 | + // S1D1F1 + // 0X0 | 0X0 | 001 | 011 | 1X0 | + // 001 | 001 | 001 | 011 | 1X0 | + // 011 | 011 | 011 | 011 | 1X0 | + // 1X0 | 1X0 | 1X0 | 1X0 | 1X0 | + + return make_mixed_bits(C{}, + ((~S1 & F0) & m0.dynamic_int_) | ((~S0 & F1) & m1.dynamic_int_), + C<(~S0 & F1) | (~S1 & F0)>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator|(MixedBits const& m, C) +{ + return make_mixed_bits(C{}, + m.dynamic_int_, + C{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator|(C s, MixedBits const& m) +{ + return m | s; +} + +// Bitwise XOR +template +CUTE_HOST_DEVICE constexpr +auto +operator^(MixedBits const& m0, MixedBits const& m1) +{ + // Truth table for (S0,D0,F0) ^ (S1,D1,F1) -> (S,D,F) + // S0D0F0 | 0X0 | 001 | 011 | 1X0 | + // S1D1F1 + // 0X0 | 0X0 | 001 | 011 | 1X0 | + // 001 | 001 | 001 | 011 | 011 | + // 011 | 011 | 011 | 001 | 001 | + // 1X0 | 1X0 | 011 | 001 | 0X0 | + + return make_mixed_bits(C<(~S0 & S1 & ~F0) | (S0 & ~S1 & ~F1)>{}, + (S0 | m0.dynamic_int_) ^ (S1 | m1.dynamic_int_), + C{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator^(MixedBits const& m, C) +{ + return make_mixed_bits(C<(~S0 & uint32_t(S1) & ~F0) | (S0 & ~uint32_t(S1))>{}, + (S0 | m.dynamic_int_) ^ uint32_t(S1), + C{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator^(C s, MixedBits const& m) +{ + return m ^ s; +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator<<(MixedBits const& m, C) +{ + return make_mixed_bits(C<(S0 << S1)>{}, + m.dynamic_int_ << S1, + C<(F0 << S1)>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +operator>>(MixedBits const& m, C) +{ + return make_mixed_bits(C<(S0 >> S1)>{}, + m.dynamic_int_ >> S1, + C<(F0 >> S1)>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto 
+shiftl(MixedBits const& m, C s) +{ + if constexpr (S1 >= 0) { + return m << s; + } else { + return m >> -s; + } +} + +template +CUTE_HOST_DEVICE constexpr +auto +shiftr(MixedBits const& m, C s) +{ + if constexpr (S1 >= 0) { + return m >> s; + } else { + return m << -s; + } +} + +// +// Upcast and Downcast +// + +template +CUTE_HOST_DEVICE constexpr +auto +safe_div(MixedBits const& m, C s) +{ + static_assert(has_single_bit(uint32_t(S1)), "Only divide MixedBits by powers of two."); + return make_mixed_bits(safe_div(C{}, s), + safe_div(m.dynamic_int_, s), + safe_div(C{}, s)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(MixedBits const& m) +{ + static_assert(has_single_bit(N), "Only divide MixedBits by powers of two."); + return safe_div(m, C{}); +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +upcast(T const& m) +{ + return safe_div(m, C{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +downcast(MixedBits const& m) +{ + static_assert(has_single_bit(N), "Only scale MixedBits by powers of two."); + return make_mixed_bits(C{}, + m.dynamic_int_ * N, + C{}); +} + +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +downcast(T const& m) +{ + return m * C{}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +max_alignment(MixedBits const&) +{ + return C{}; +} + +template +CUTE_HOST_DEVICE constexpr +C +max_alignment(C const& c) +{ + return c; +} + +// +// Convert a Pow2Layout+Coord to a MixedBits +// + +template +CUTE_HOST_DEVICE constexpr +auto +to_mixed_bits(Shape const& shape, Stride const& stride, Coord const& coord) +{ + if constexpr (is_tuple::value && is_tuple::value && is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); + static_assert(tuple_size::value == tuple_size::value, "Mismatched ranks"); + return transform_apply(shape, stride, coord, [](auto const& s, auto const& d, auto const& c) { return to_mixed_bits(s,d,c); }, + [](auto const&... 
a) { return (a ^ ...); }); + } else if constexpr (is_integral::value && is_integral::value && is_integral::value) { + static_assert(decltype(shape*stride)::value == 0 || has_single_bit(decltype(shape*stride)::value), "Requires pow2 shape*stride."); + return make_mixed_bits(Int<0>{}, coord * stride, (shape - Int<1>{}) * stride); + } else { + static_assert(is_integral::value && is_integral::value && is_integral::value, "Either Shape, Stride, and Coord must be all tuples, or they must be all integral (in the sense of cute::is_integral)."); + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +to_mixed_bits(Layout const& layout, Coord const& coord) +{ + return to_mixed_bits(layout.shape(), layout.stride(), idx2crd(coord, layout.shape())); +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(Swizzle const&) +{ + printf("Sw<%d,%d,%d>", B, M, S); +} + +template +CUTE_HOST_DEVICE void print(MixedBits const& m) +{ + printf("M_%u|(%u&%u)=%u", S, m.dynamic_int_, F, uint32_t(m)); +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, Swizzle const&) +{ + return os << "Sw<" << B << "," << M << "," << S << ">"; +} + +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, MixedBits const& m) +{ + return os << "M_" << S << "|(" << m.dynamic_int_ << "&" << F << ")=" << uint32_t(m); +} +#endif // !defined(__CUDACC_RTC__) + +// +// Helper Function +// +template // Default No-Swizzle +struct get_swizzle { using type = Swizzle<0,4,3>; }; + +template +using get_swizzle_t = typename get_swizzle::type; + +} // end namespace cute diff --git a/include/cute/swizzle_layout.hpp b/include/cute/swizzle_layout.hpp new file mode 100644 index 0000000000..7f7161bc32 --- /dev/null +++ b/include/cute/swizzle_layout.hpp @@ -0,0 +1,584 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::Layout +#include // cute::ComposedLayout +#include // cute::Swizzle, cute::get_swizzle primary template + +/* Specialized functionality for a ComposedLayout of the form + * InvolutionFn o Offset o LayoutB + * where the InvolutionFn is a Swizzle and is not linear (hence the need for the Offset). + * + * Because these are specializations for core functions of ComposedLayout, these Swizzle Layouts + * provide similar functionality to Layout including tiling, partitioning, + * coordinate-to-index mapping and layout manipulations, but are not considered "normal" layouts. + * For example, these provide shape() and size() functions, but do not provide stride() functions. + * + * Furthermore, each of these specializations uses Swizzle<>-specific knowledge in its implementation and + * attempts to decay itself to a normal-layout with dynamic or static strides when certain slicing conditions + * are met. This is possible by determining the subdomain of the Swizzle<> function that is identity and + * testing if LayoutB's codomain is contained within it. In general, MizedBits is used as the Offset to track + * statically-vs-dynamically known bits in the Offset to improve the decay to static or dynamic normal layouts. + */ + +namespace cute +{ + +// +// Helper Function +// +template +struct get_swizzle,Offset,LayoutB>> { using type = Swizzle; }; + +// +// Constructors +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_layout(Swizzle const& sxor) +{ + return composition(sxor, Layout,Int<1>>{}); +} + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +transfer_swizzle(Layout const& old_layout, + Layout const& new_layout) +{ + // Our goal is to determine a new swizzle for the strides in new_layout for consistent vectorizations + + // This is accomplished by identifying + // S o L :=: S? 
o L* + // We identify the "active" portion of S by computing (P o L)(c*) where P is a projection generated by S + // Then that active identifier is transformed through the layouts: + // L*(L[(P o L)(c*)]) + // which is a new swizzle identifier for S?, the new swizzle + + // Projections of the swizzle layout for composition, P + auto swizzle_only_zy = make_layout(make_shape (Int<(1 << M)>{}, Int<(1 << B)>{}, Int<(1 << (abs(S)-B))>{}, Int<(1 << B )>{}, Int<1>{}), + make_stride( Int<0>{}, Int<(1 << M)>{}, Int<0>{}, Int<(1 << (M+abs(S)))>{}, Int<0>{})); + + // Compose with the tile to get the swizzle projection, P o L [The Z and Y contributing portions of L] + auto layout_only_zy = composition(swizzle_only_zy, old_layout); + // Transform the end coordinate to get the active bits of the swizzle, (P o L)(c*) + auto swizzle_active_bits = layout_only_zy(size(layout_only_zy)-Int<1>{}); + + // Get the Z bit and the Y bits -- keep only those that are active in Z *and* Y + auto zzz_msk = typename Swizzle::zzz_msk{}; + auto yyy_msk = typename Swizzle::yyy_msk{}; + auto msk_sft = typename Swizzle::msk_sft{}; + auto active_Z = swizzle_active_bits & shiftr(swizzle_active_bits, msk_sft) & zzz_msk; + auto active_Y = swizzle_active_bits & shiftr(swizzle_active_bits, -msk_sft) & yyy_msk; + + // Pass the identifiers through the old layout and new layout to make a new swizzle identifier, L*(L[(P o L)(c*)]) + auto new_active_Z = new_layout(old_layout.get_1d_coord(active_Z)); + auto new_active_Y = new_layout(old_layout.get_1d_coord(active_Y)); + + // Use this new swizzle identifier to construct the new swizzle for new_layout + // (this also makes sure it's a "valid" swizzle that Swizzle can represent) + return composition(make_swizzle(), new_layout); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +make_fragment_like(ComposedLayout,Offset,Layout> const& layout) +{ + return make_fragment_like(layout.layout_b()); +} + +// +// Utilities +// + +namespace detail { + +// Get just the Swizzle part of a composed layout. +template +CUTE_HOST_DEVICE constexpr +auto +get_swizzle_portion(ComposedLayout,Offset,LayoutB>) +{ + return Swizzle{}; +} + +// A non-swizzled layout's "Swizzle part" is the identity swizzle. +template +CUTE_HOST_DEVICE constexpr +auto +get_swizzle_portion(Layout) +{ + return Swizzle<0,4,3>{}; +} + +// Get the "non-swizzle" part of a composed layout, +// which is the underlying (non-composed) Layout. +template +CUTE_HOST_DEVICE constexpr +auto +get_nonswizzle_portion(ComposedLayout,Offset,LayoutB> const& slayout) +{ + return slayout.layout_b(); +} + +// The non-swizzle part of a non-swizzled layout is just the Layout. +template +CUTE_HOST_DEVICE constexpr +auto +get_nonswizzle_portion(Layout const& slayout) +{ + return slayout; +} + +} // namespace detail + +// +// Slice a Swizzled ComposedLayout +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +make_swizzle_strides(true_type, + IntZ const& Z, + IntY const& Y, + Offset const& offset, + int_sequence) +{ + // Below is an optimized/compressed version of: + //return cute::make_tuple((swizzle(offset + Z*Int<(1 << I)>{}) - swizzle(offset))...); + // with knowledge of Swizzle, I... 
ranges for each B bits, + // and the layout won't slice along z-bits that are already set + + // y\z 0 1 + // 0 Z DC + // 1 -Z DC + + return cute::make_tuple(conditional_return((offset & (Y << Int{})) == Int<0>{}, Z * Int<(1 << I)>{}, -Z * Int<(1 << I)>{})...); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_swizzle_strides(false_type, + IntZ const& Z, + IntY const& Y, + Offset const& offset, + int_sequence) +{ + // Below is an optimized/compressed version of: + //return cute::make_tuple((swizzle(offset + Y*Int<(1 << I)>{}) - swizzle(offset))...); + // with knowledge of Swizzle, I... ranges for each B bits, + // and the layout won't slice along y-bits that are already set + + // y\z 0 1 + // 0 Y+Z Y-Z + // 1 DC DC + + return cute::make_tuple(conditional_return((offset & (Z << Int{})) == Int<0>{}, (Y+Z) * Int<(1 << I)>{}, (Y-Z) * Int<(1 << I)>{})...); +} + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr +auto +slice_and_offset(Coord const& coord, ComposedLayout,Offset,Layout> const& layout) +{ + if constexpr (all_underscore::value) { + // Skip the expensive/complicated attempt to decay to a normal layout and just reshape + return cute::make_tuple(composition(layout.layout_a(), layout.offset(), slice(coord, layout.layout_b())), Int<0>{}); + } else { + + // Projections of the swizzle layout for composition + auto sw = make_layout(make_shape(Int<(1 << M)>{}, Int<(1 << B)>{}, Int<(1 << (abs(S)-B))>{}, Int<(1 << B)>{}, Int<1>{})); + + auto swizzle_anti_zy = make_layout(shape(sw), + make_stride(stride<0>(sw), Int<0>{}, stride<2>(sw), Int<0>{}, size(sw))); + auto swizzle_only_zy = make_layout(shape(sw), + make_stride( Int<0>{}, stride<1>(sw), Int<0>{}, stride<3>(sw), Int<0>{})); + + // The portion of the layout that is not yet consumed + auto sliced_layout = slice(coord, layout.layout_b()); + + // The portion of the layout that we are consuming now + auto diced_layout = dice(coord, layout.layout_b()); + auto diced_coord = dice(coord, coord); + + auto diced_layout_anti_zy = composition(swizzle_anti_zy, diced_layout); + auto diced_layout_only_zy = composition(swizzle_only_zy, diced_layout); + + // New swizzle and offset + auto swizzle = layout.layout_a(); + // offset_only_zy interacts with swizzle and gets accumulated with layout.offset() + // being careful about the static/dynamic contributions from diced_layout and diced_coord + auto offset_only_zy = layout.offset() ^ to_mixed_bits(diced_layout_only_zy, diced_coord); + // offset_anti_zy always gets passed through, no interaction with swizzle + auto offset_anti_zy = diced_layout_anti_zy(diced_coord); + + // If Layout's codomain hits on Y AND Z, then it's not reducible + // If Layout's codomain hits on Y XOR Z, then it's dynamic-normal + // If Layout's codomain hits on neither Y NOR Z, then it's static-normal + + // If the sliced_layout hits two bits that are swizzled together, then don't attempt to decay + + // Compose with the layout to get the swizzle projection, P o L [The Z and Y contributing portions of L] + // (this also tests that shape/stride of layout compose with swizzle) + auto sliced_layout_only_zy = composition(swizzle_only_zy, sliced_layout); + // Transform the end coordinate to get the active bits of the swizzle, (P o L)(c*) + [[maybe_unused]] auto swizzle_active_bits = sliced_layout_only_zy(size(sliced_layout_only_zy)-Int<1>{}); + + // Determine if any active bits collide under the swizzle for potential decay + if constexpr (is_constant<0, decltype(not (swizzle_active_bits & 
~swizzle(swizzle_active_bits)))>::value) + { // Hits on Y AND Z, so it's not reducible + return cute::make_tuple(composition(swizzle, offset_only_zy, sliced_layout), offset_anti_zy); + } else + { // Misses on Y or Z, so it's static-normal or dynamic-normal + + // Lowest bit of the Z and Y masks + auto Z = typename Swizzle::zzz_msk{} & -typename Swizzle::zzz_msk{}; + auto Y = typename Swizzle::yyy_msk{} & -typename Swizzle::yyy_msk{}; + auto stride_lo = detail::make_swizzle_strides(Z < Y, Z, Y, offset_only_zy, make_int_sequence{}); + auto stride_hi = detail::make_swizzle_strides(Z > Y, Z, Y, offset_only_zy, make_int_sequence{}); + + // Construct a (dynamic) layout that we can perform the composition with + auto swizzle_layout = make_layout(make_shape (Int<(1 << M)>{}, repeat(Int<2>{}), Int<(1 << (abs(S)-B))>{}, repeat(Int<2>{}), Int< 1>{}), + make_stride(Int< 1>{}, stride_lo, Int<(1 << (M+B))>{}, stride_hi , Int<(1 << (M+B+abs(S)))>{})); + + // Decay to a normal layout with offset + return cute::make_tuple(composition(swizzle_layout, sliced_layout), + swizzle(offset_only_zy) + offset_anti_zy); + } + } + + CUTE_GCC_UNREACHABLE; +} + +// +// composition +// + +// Ignore identity case +template +CUTE_HOST_DEVICE constexpr +auto +composition(Swizzle<0,M,S> const&, + Int<0> const&, + Layout const& layout) +{ + return layout; +} + +template +CUTE_HOST_DEVICE constexpr +auto +composition(Swizzle const& sxor, + Layout const& layout) +{ + return composition(sxor, Int<0>{}, layout); +} + +template +CUTE_HOST_DEVICE constexpr +auto +composition(Layout const& a, + Swizzle const& b) +{ + // Get the Z bits and the Y bits + auto active_Y = a(typename Swizzle::yyy_msk{}); + auto active_Z = a(typename Swizzle::zzz_msk{}); + + // Works in simple cases... but could be greatly generalized + + return composition(make_swizzle(), a); +} + +// +// inverse +// + +// Specialization to attempt to pass-through the Swizzle back to the left -- Needed? +template +CUTE_HOST_DEVICE constexpr +auto +right_inverse(ComposedLayout,Offset,Layout> const& layout) +{ + if constexpr (is_constant<0, Offset>::value) { + return composition(right_inverse(layout.layout_b()), layout.layout_a()); + } else { + return composition(right_inverse(layout.layout_b()), right_inverse(layout.offset()), right_inverse(layout.layout_a())); + } +} + +// Specialization to attempt to pass-through the Swizzle back to the left -- Needed? 
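+// Note: a Swizzle is an involution (applying it twice is the identity), so it acts as
+// its own left and right inverse; see the Swizzle overloads of right_inverse/left_inverse below.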
+template +CUTE_HOST_DEVICE constexpr +auto +left_inverse(ComposedLayout,Offset,Layout> const& layout) +{ + if constexpr (is_constant<0, Offset>::value) { + return composition(left_inverse(layout.layout_b()), layout.layout_a()); + } else { + return composition(left_inverse(layout.layout_b()), left_inverse(layout.offset()), left_inverse(layout.layout_a())); + } +} + +template +CUTE_HOST_DEVICE constexpr +Swizzle +right_inverse(Swizzle const& sw) +{ + return sw; +} + +template +CUTE_HOST_DEVICE constexpr +Swizzle +left_inverse(Swizzle const& sw) +{ + return sw; +} + +// Kludge -- Probably want an OffsetFn here instead +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +right_inverse(T const& t) +{ + return -t; +} + +// Kludge -- Probably want an OffsetFn here instead +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +left_inverse(T const& t) +{ + return -t; +} + +// +// Upcast and Downcast +// + +template +CUTE_HOST_DEVICE constexpr +auto +upcast(Swizzle const& swizzle) +{ + static_assert(has_single_bit(N), "N must be a power of two"); + constexpr int log2_n = bit_width(uint32_t(N)) - 1; + constexpr int NewM = M - log2_n; + if constexpr (NewM >= 0) { + return Swizzle{}; + } else { + return Swizzle{}; + } + + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +downcast(Swizzle const& swizzle) +{ + static_assert(has_single_bit(N), "N must be a power of two"); + constexpr int log2_n = bit_width(uint32_t(N)) - 1; + return Swizzle{}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +recast_layout(Swizzle const& swizzle) +{ + using scale = decltype(trait_ratio(sizeof_bits{}, sizeof_bits{})); + if constexpr (scale::num == 1 && scale::den == 1) { + return swizzle; + } + else if constexpr (scale::num == 1) { + return downcast(swizzle); + } + else if constexpr (scale::den == 1) { + return upcast(swizzle); + } + else { + return downcast(upcast(layout)); + } + CUTE_GCC_UNREACHABLE; +} + +template +CUTE_HOST_DEVICE constexpr +auto +max_alignment(Swizzle const&) +{ + return Int<(1 << M)>{}; +} + +template +CUTE_HOST_DEVICE constexpr +auto +max_alignment(ComposedLayout,Offset,LayoutB> const& layout) +{ + return gcd(max_alignment(layout.layout_a()), + max_alignment(layout.offset()), + max_alignment(layout.layout_b())); +} + +// +// Other operations +// + +template +CUTE_HOST_DEVICE constexpr +auto +max_common_layout(ComposedLayout,Offset,LayoutB> const& a, + Layout const& b) +{ + auto common = max_common_layout(a.layout_b(), b); + auto base = Int<(1 << M)>{}; + if constexpr (base < size(common)) { + return common.compose(base); // Truncate common to size base + } else { + return common; + } +} + +template +CUTE_HOST_DEVICE constexpr +auto +max_common_layout(Layout const& a, + ComposedLayout,Offset,LayoutB> const& b) +{ + return max_common_layout(b, a); +} + +template +CUTE_HOST_DEVICE constexpr +auto +max_common_vector(ComposedLayout,Offset,LayoutB> const& a, + Layout const& b) +{ + // This assumes that Offset is in the YZ domain of the Swizzle... 
+ return cute::min(max_common_vector(a.layout_b(), b), Int<(1 << M)>{}); +} + +template +CUTE_HOST_DEVICE constexpr +auto +max_common_vector(Layout const& a, + ComposedLayout,Offset,LayoutB> const& b) +{ + return max_common_vector(b, a); +} + +template +CUTE_HOST_DEVICE constexpr +auto +max_common_vector(ComposedLayout,Offset0,LayoutB0> const& a, + ComposedLayout,Offset1,LayoutB1> const& b) +{ + // Typical impl is composition(a, right_inverse(b)) + // so this is Sw0 o B0 o rinv(Sw1 o B1) = Sw0 o B0 o rinv(B1) o Sw1 + auto vec = max_common_vector(a.layout_b(), b.layout_b()); + + // This assumes that Offset is in the YZ domain of the Swizzle... + if constexpr (Swizzle{} == Swizzle{}) { + return vec; + } else { + return cute::min(vec, Int<(1 << M0)>{}, Int<(1 << M1)>{}); + } + + CUTE_GCC_UNREACHABLE; +} + +/////////////////////////////////////////////////////////////////////////////// +// ComposedLayout as second argument is often more difficult... + +template +CUTE_HOST_DEVICE constexpr +auto +logical_product(Layout const& layout, + ComposedLayout,Offset,LayoutT> const& tiler) +{ + CUTE_STATIC_ASSERT_V(tiler.offset() == Int<0>{}, "Require Swizzle offset == 0."); + // The new layout -- if swizzle wasn't an issue, this is the result + // our goal is to determine a new swizzle for these strides + auto new_layout = logical_product(layout, tiler.layout_b()); + + // This is accomplished by identifying + // S o L :=: S? o L* + // We identify the "active" portion of S by computing (P o L)(c*) where P is a projection generated by S + // Then that active identifier is transformed through the layouts: + // L*(L[(P o L)(c*)]) + // which is a new swizzle identifier for S?, the new swizzle + + // Projections of the swizzle layout for composition, P + auto swizzle_only_zy = make_layout(make_shape (Int<(1 << M)>{}, Int<(1 << B)>{}, Int<(1 << (abs(S)-B))>{}, Int<(1 << B )>{}, Int<1>{}), + make_stride( Int<0>{}, Int<(1 << M)>{}, Int<0>{}, Int<(1 << (M+abs(S)))>{}, Int<0>{})); + + // Compose with the tiler to get the swizzle projection, P o L [The Z and Y contributing portions of L] + auto layout_only_zy = composition(swizzle_only_zy, tiler.layout_b()); + // Transform the end coordinate to get the active bits of the swizzle, (P o L)(c*) + auto swizzle_active_bits = layout_only_zy(size(layout_only_zy)-Int<1>{}); + // Get the Z bit and the Y bits + auto active_Z = swizzle_active_bits & typename Swizzle::zzz_msk{}; + auto active_Y = swizzle_active_bits & typename Swizzle::yyy_msk{}; + + // Pass the identifiers through the old layout and new layout to make a new swizzle identifier, L*(L[(P o L)(c*)]) + auto new_active_Z = new_layout(Int<0>{}, tiler.layout_b()[active_Z]); + auto new_active_Y = new_layout(Int<0>{}, tiler.layout_b()[active_Y]); + + // Use this new swizzle identifier to construxt the new swizzle for new_layout + // (this also makes sure it's a "valid" swizzle that Swizzle can represent) + return composition(make_swizzle(), new_layout); +} + +} // end namespace cute diff --git a/include/cute/tensor.hpp b/include/cute/tensor.hpp new file mode 100644 index 0000000000..3f3335b63d --- /dev/null +++ b/include/cute/tensor.hpp @@ -0,0 +1,58 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +// +// Extended Engines +// + +#include +#include +#include +#include + +// +// Tensor Algorithms +// + +#include +#include +#include +#include +#include +#include +#include + +#include +#include + diff --git a/include/cute/tensor_impl.hpp b/include/cute/tensor_impl.hpp new file mode 100644 index 0000000000..2be19c15e3 --- /dev/null +++ b/include/cute/tensor_impl.hpp @@ -0,0 +1,1206 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief This file contains the definition of Tensor as well as classes/functions most closely associated with it. + + For backwards-compatibility, "tensor.hpp" is the "entrypoint" header for a collection of classes and utilities + that are adjacent to Tensor, e.g. fill(). Whereas this file contains the actual definition of Tensor and + a small set of functions central to its usage. + + Within the CUTLASS codebase, favor not including "tensor.hpp" wherever possible; instead include "tensor_impl.hpp" + along with other specific headers that you need. This helps to avoid circular includes and to reduce build time. +*/ + +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::Shape +#include // cute::is_composed_layout +#include // cute::recast_ptr +#include // cute::iterator_traits +#include // cute::array_aligned +#include // cute::array_subbyte +#include // cute::tuple +#include // cute::is_integral +#include // __CUTE_REQUIRES + +namespace cute +{ + +// +// Engine -- owning or non-owning data store +// + +// concept Engine { +// using iterator = ; +// using value_type = ; +// using element_type = ; +// using reference = ; +// iterator begin(); +// }; + +template +struct ArrayEngine +{ + using Storage = typename conditional<(sizeof_bits::value % 8 == 0), + array_aligned, + array_subbyte>::type; + using iterator = typename Storage::iterator; + using reference = typename iterator_traits::reference; + using element_type = typename iterator_traits::element_type; + using value_type = typename iterator_traits::value_type; + Storage storage_; + + CUTE_HOST_DEVICE constexpr auto begin() const { return storage_.begin(); } + CUTE_HOST_DEVICE constexpr auto begin() { return storage_.begin(); } +}; + +// Specialization for sparse_elem tensor allocation/iteration +// NOTE: This can and should be used for allocation of SMEM as well! +// Fuse these two ArrayEngines? 
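+// The specialization below exposes its storage through a sparse_ptr iterator
+// (hence the recast in begin()), so dereferencing yields sparse_elem values
+// rather than the raw underlying values.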
+template +struct ArrayEngine, N> +{ + static_assert(N % S == 0, "Expected a multiple of the sparsity."); + using value_type = sparse_elem; + using Storage = typename conditional<(sizeof_bits::value % 8 == 0), + array_aligned, + array_subbyte>::type; + using iterator = sparse_ptr*>; + using reference = typename iterator_traits::reference; + using element_type = typename iterator_traits::element_type; + Storage storage_; + + CUTE_HOST_DEVICE constexpr auto begin() const { return recast_ptr(storage_.begin()); } + CUTE_HOST_DEVICE constexpr auto begin() { return recast_ptr(storage_.begin()); } +}; + +template +struct ViewEngine +{ + using iterator = Iterator; + using reference = typename iterator_traits::reference; + using element_type = typename iterator_traits::element_type; + using value_type = typename iterator_traits::value_type; + iterator storage_; + + CUTE_HOST_DEVICE constexpr iterator const& begin() const { return storage_; } + CUTE_HOST_DEVICE constexpr iterator & begin() { return storage_; } +}; + +template +struct ConstViewEngine +{ + using iterator = Iterator; + using reference = typename iterator_traits::reference; + using element_type = typename iterator_traits::element_type; + using value_type = typename iterator_traits::value_type; + iterator storage_; + + CUTE_HOST_DEVICE constexpr iterator const& begin() const { return storage_; } +}; + +// +// Tensor +// + +template +struct Tensor +{ + using iterator = typename Engine::iterator; + using value_type = typename Engine::value_type; + using element_type = typename Engine::element_type; + using reference = typename Engine::reference; + + using engine_type = Engine; + using layout_type = Layout; + + CUTE_HOST_DEVICE constexpr + Tensor() {} + + CUTE_HOST_DEVICE constexpr + Tensor(Engine const& engine, Layout const& layout) + : rep_(layout, engine) { + } + + // + // Accessors + // + + static constexpr int rank = Layout::rank; + + CUTE_HOST_DEVICE constexpr + decltype(auto) + tensor() const { + return *this; + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + engine() const { + return get<1>(rep_); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + engine() { + return get<1>(rep_); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + data() const { + return engine().begin(); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + data() { + return engine().begin(); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + layout() const { + return get<0>(rep_); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + shape() const { + return layout().shape(); + } + + CUTE_HOST_DEVICE constexpr + auto + size() const { + return cute::size(shape()); + } + + CUTE_HOST_DEVICE constexpr + decltype(auto) + stride() const { + return layout().stride(); + } + + // + // Indexing op() and op[] + // + + // Index into this tensor like an array by computing the offset via layout() + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator[](Coord const& coord) { + return data()[layout()(coord)]; + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator[](Coord const& coord) const { + return data()[layout()(coord)]; + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator()(Coord const& coord) { + if constexpr (has_underscore::value) { + auto const& [sliced_layout,offset] = slice_and_offset(coord, layout()); + return make_tensor(data() + offset, sliced_layout); + } else { + return data()[layout()(coord)]; + } + + CUTE_GCC_UNREACHABLE; + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + 
operator()(Coord const& coord) const { + if constexpr (has_underscore::value) { + auto const& [sliced_layout,offset] = slice_and_offset(coord, layout()); + return make_tensor(data() + offset, sliced_layout); + } else { + return data()[layout()(coord)]; + } + + CUTE_GCC_UNREACHABLE; + } + + // op() convenience function for multi-dimensional coordinates + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) { + return operator()(make_coord(c0,c1,cs...)); + } + + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const { + return operator()(make_coord(c0,c1,cs...)); + } + + // + // Compose + // + + template + CUTE_HOST_DEVICE constexpr + auto + compose(Layouts const&... layouts) { + return make_tensor(data(), layout().compose(layouts...)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + compose(Layouts const&... layouts) const { + return make_tensor(data(), layout().compose(layouts...)); + } + + // + // Tile + // + + template + CUTE_HOST_DEVICE constexpr + auto + tile(Layouts const&... layouts) { + return make_tensor(data(), layout().tile(layouts...)); + } + + template + CUTE_HOST_DEVICE constexpr + auto + tile(Layouts const&... layouts) const { + return make_tensor(data(), layout().tile(layouts...)); + } + + // + // Utility + // + + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_1d_coord(Int const& linear_idx) const { + return layout().get_1d_coord(linear_idx); + } + + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_hier_coord(Int const& linear_idx) const { + return layout().get_hier_coord(linear_idx); + } + + template ::value)> + CUTE_HOST_DEVICE constexpr + auto + get_flat_coord(Int const& linear_idx) const { + return layout().get_flat_coord(linear_idx); + } + + cute::tuple rep_; +}; + +template +struct is_tensor : false_type {}; +template +struct is_tensor> : true_type {}; +template +constexpr bool is_tensor_v = is_tensor::value; + +// Customization point for creation of owning and non-owning Tensors +template +struct MakeTensor +{ + template + CUTE_HOST_DEVICE constexpr auto + operator()(Arg0 const& arg0, Args const&... args) const + { + if constexpr (has_dereference::value) { + // Construct a non-owning Tensor + using Engine = ViewEngine; + if constexpr (sizeof...(Args) == 1 && (is_layout::value && ...)) { + // Forward a Layout + return Tensor{Engine{arg0}, args...}; + } else { + // Construct a Layout from Args + return Tensor{Engine{arg0}, make_layout(args...)}; + } + } else { + // Construct an owning Tensor + static_assert((is_static::value && ... && is_static::value), + "Dynamic owning tensors not supported"); + if constexpr (sizeof...(Args) == 0 && is_layout::value) { + // Forward a Layout + using Layout = Arg0; + using Engine = ArrayEngine>; + return Tensor(); + } else { + // Construct a Layout from Args + using Layout = decltype(make_layout(arg0, args...)); + using Engine = ArrayEngine>; + return Tensor(); + } + } + } +}; + +// +// make_tensor +// + +// Make an owning Tensor that will allocate a static array +// e.g. make_tensor(Int<12>{}) +template +CUTE_HOST_DEVICE constexpr +auto +make_tensor(Args const&... args) +{ + static_assert((not has_dereference::value && ...), "Expected layout args... in make_tensor(args...)"); + return MakeTensor{}(args...); +} + +// Make a non-owning Tensor that will use a pointer (view) +// e.g. 
make_tensor(vec.data(), 12) +template +CUTE_HOST_DEVICE constexpr +auto +make_tensor(Iterator const& iter, Args const&... args) +{ + static_assert(has_dereference::value, "Expected iterator iter in make_tensor(iter, args...)"); + static_assert((not has_dereference::value && ...), "Expected layout args... in make_tensor(iter, args...)"); + return MakeTensor{}(iter, args...); +} + +// +// make_tensor_like +// Make a register tensor the same type and shape and (if possible) order as another tensor +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_tensor_like(Layout const& layout) +{ + return make_tensor(make_layout_like(layout)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_tensor_like(Tensor const& tensor) +{ + return make_tensor_like(tensor.layout()); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_tensor_like(Tensor const& tensor) +{ + return make_tensor_like(tensor.layout()); +} + +// +// make_fragment_like +// Make a tensor the same shape and (if possible) order as another tensor, with special +// consideration of the 0th mode. The 0th mode is commonly used for MMA_Atoms or Copy_Atoms +// so this allocates the 0th mode with LayoutLeft regardless of the reference layout. +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_fragment_like(Layout const& layout) +{ + return make_tensor(make_fragment_like(layout)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_fragment_like(Tensor const& tensor) +{ + return make_fragment_like(tensor.layout()); +} + +template +CUTE_HOST_DEVICE constexpr +auto +make_fragment_like(Tensor const& tensor) +{ + return make_fragment_like(tensor.layout()); +} + +// +// make_counting_tensor +// Make a tensor from a layout by binding it to a counting iter with 0-offset of the same profile as the codomain. +// + +template ::value)> +CUTE_HOST_DEVICE constexpr +auto +make_counting_tensor(Layout const& layout) +{ + return make_tensor(make_inttuple_iter(repeat_like(coshape(layout), Int<0>{})), layout); +} + +// +// make_identity_tensor +// Make a tensor that maps coordinates within a shape to themselves. 
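+// For example (illustrative), make_identity_tensor(make_shape(4,8)) is a (4,8) tensor t
+// with t(i,j) == (i,j); such coordinate tensors are commonly compared against problem
+// bounds to build predicates for partial tiles.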
+// + +template +CUTE_HOST_DEVICE constexpr +auto +make_identity_tensor(Shape const& shape) +{ + return make_counting_tensor(make_identity_layout(shape)); +} + +// +// Utilities +// + +// Return the subtensor of a mode +template +CUTE_HOST_DEVICE constexpr +auto +tensor(Tensor&& tensor) +{ + if constexpr (sizeof...(Is) == 0) { + return tensor; + } else { + return make_tensor(tensor.data(), get(tensor.layout())); + } + + CUTE_GCC_UNREACHABLE; +} + +// Return the layout of a mode +template +CUTE_HOST_DEVICE constexpr +auto +layout(Tensor const& tensor) +{ + return layout(tensor.layout()); +} + +// Return the shape of a mode +template +CUTE_HOST_DEVICE constexpr +auto +shape(Tensor const& tensor) +{ + return shape(tensor.layout()); +} + +// Return the stride of a mode +template +CUTE_HOST_DEVICE constexpr +auto +stride(Tensor const& tensor) +{ + return stride(tensor.layout()); +} + +// Return the number of elements in a mode +template +CUTE_HOST_DEVICE constexpr +auto +size(Tensor const& tensor) +{ + return size(tensor.layout()); +} + +// Return the rank of a mode +template +CUTE_HOST_DEVICE constexpr +auto +rank(Tensor const& tensor) +{ + return rank(tensor.layout()); +} + +// Return the depth of a mode +template +CUTE_HOST_DEVICE constexpr +auto +depth(Tensor const& tensor) +{ + return depth(tensor.layout()); +} + +// +// Operations to manipulate Tensors like a Layout or IntTuple +// These are implemented with explicit modifier overloads because these +// methods likely also have a general IntTuple overload that can shadow. +// + +template +CUTE_HOST_DEVICE constexpr +auto +flatten(Tensor const& tensor) { + return make_tensor(tensor.data(), flatten(tensor.layout())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +flatten(Tensor& tensor) { + return make_tensor(tensor.data(), flatten(tensor.layout())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +flatten(Tensor&& tensor) { + return make_tensor(tensor.data(), flatten(tensor.layout())); +} + +template > +CUTE_HOST_DEVICE constexpr +auto +coalesce(Tensor const& tensor, Profile const& profile = {}) { + return make_tensor(tensor.data(), coalesce(tensor.layout(), profile)); +} + +template > +CUTE_HOST_DEVICE constexpr +auto +coalesce(Tensor& tensor, Profile const& profile = {}) { + return make_tensor(tensor.data(), coalesce(tensor.layout(), profile)); +} + +template > +CUTE_HOST_DEVICE constexpr +auto +coalesce(Tensor&& tensor, Profile const& profile = {}) { + return make_tensor(tensor.data(), coalesce(tensor.layout(), profile)); +} + +// Replace the modes in layout that have a 0-stride with a 1-size +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Tensor const& tensor) { + return make_tensor(tensor.data(), filter_zeros(tensor.layout())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Tensor& tensor) { + return make_tensor(tensor.data(), filter_zeros(tensor.layout())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Tensor&& tensor) { + return make_tensor(tensor.data(), filter_zeros(tensor.layout())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Tensor const& tensor, Profile const& profile) +{ + return make_tensor(tensor.data(), filter_zeros(tensor.layout(), profile)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Tensor& tensor, Profile const& profile) +{ + return make_tensor(tensor.data(), filter_zeros(tensor.layout(), profile)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter_zeros(Tensor&& tensor, Profile const& profile) +{ + return 
make_tensor(tensor.data(), filter_zeros(tensor.layout(), profile)); +} + +// Remove all of the 0-strides and 1-sizes +template +CUTE_HOST_DEVICE constexpr +auto +filter(Tensor const& tensor) { + return make_tensor(tensor.data(), filter(tensor.layout())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter(Tensor& tensor) { + return make_tensor(tensor.data(), filter(tensor.layout())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +filter(Tensor&& tensor) { + return make_tensor(tensor.data(), filter(tensor.layout())); +} + +// Group the modes [B,E) into a single mode +// e.g. group<2,4>(make_tensor(Layout>{})) +// => make_tensor(Layout,_5,_6>>{}) +template +CUTE_HOST_DEVICE constexpr +auto +group_modes(Tensor const& tensor) { + return make_tensor(tensor.data(), group(tensor.layout())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +group_modes(Tensor& tensor) { + return make_tensor(tensor.data(), group(tensor.layout())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +group_modes(Tensor&& tensor) { + return make_tensor(tensor.data(), group(tensor.layout())); +} + +// Return the subtensor of a range of modes +template +CUTE_HOST_DEVICE constexpr +auto +take(Tensor const& tensor) { + return make_tensor(tensor.data(), take(tensor.layout())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +take(Tensor& tensor) { + return make_tensor(tensor.data(), take(tensor.layout())); +} + +template +CUTE_HOST_DEVICE constexpr +auto +take(Tensor&& tensor) { + return make_tensor(tensor.data(), take(tensor.layout())); +} + +// Return a tensor with the same shape as input but offset by a given coordinate +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +domain_offset(Coord const& coord, Tensor&& tensor) +{ + auto [layout, ptr_offset] = domain_offset(coord, tensor.layout()); + return make_tensor(static_cast(tensor).data() + ptr_offset, layout); +} + +// +// Recast +// + +// NOTE: This is very dangerous to do +// -- doesn't check dynamic integer divisibility +// -- doesn't check alignment + +template +CUTE_HOST_DEVICE constexpr +auto +recast(Tensor&& tensor) +{ + using OldType = typename remove_cvref_t::value_type; + auto old_layout = tensor.layout(); + auto new_layout = recast_layout(old_layout); + + // If this is an upcast of a normal Layout with static negative strides, then offset as well + if constexpr (sizeof(OldType) < sizeof(NewType) && not is_composed_layout::value) { + auto shape_diff = transform(flatten(old_layout.shape()), flatten(new_layout.shape()), minus{}); + auto extent_diff = transform(shape_diff, flatten(old_layout.stride()), multiplies{}); + auto offset = fold(extent_diff, Int<0>{}, [](auto const& i, auto const& a) { return i + cute::min(a,Int<0>{}); }); + + return make_tensor(recast_ptr(static_cast(tensor).data() + offset), new_layout); + } else { + return make_tensor(recast_ptr(static_cast(tensor).data() ), new_layout); + } + + CUTE_GCC_UNREACHABLE; +} + +// +// max_common_vector +// + +/* Return Int such that N is the maximum number of contiguous elements + * that logically correspond in the tensors of @a a and @a b. This is, + * the number of elements that could reasonably be vectorized into a single load/store. + * + * @returns Int with N >= 0 + * + * A return value of Int<0> indicates that no such conclusion can be made and no + * vectorization should be attempted. + * + * Note that the return value does NOT include alignment concerns such as the pointer value and + * the divisbility of dynamic strides. 
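+ *
+ * For example (illustrative), if the two value types differ (say float vs. half_t),
+ * the copy is also a conversion, so Int<0> is returned and no vectorization is attempted.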
+ */ +template +CUTE_HOST_DEVICE constexpr +auto +max_common_vector(Tensor const& a, + Tensor const& b) +{ + using SrcType = typename SrcEngine::value_type; + using SrcRef = typename SrcEngine::reference; + using DstType = typename DstEngine::value_type; + using DstRef = typename DstEngine::reference; + + // Determine if vectorization candidates at all + if constexpr (// Should be the same value_types, else the copy is also performing a cast + cute::is_same::value && + // The types should be trivially copyable so that vectorization is valid + is_trivially_copyable::value && + is_trivially_copyable::value && + // Should be load/storing real data, rather than implicit iterators or such + is_reference::value && + is_reference::value) + { + return max_common_vector(a.layout(), b.layout()); + } else { + return Int<0>{}; + } + + CUTE_GCC_UNREACHABLE; +} + +/* Return a layout that points to the maximum number of contiguous elements + * that logically correspond in the tensors of @a a and @a b. This is, + * the elements that could reasonably be "vectorized" into a single load/store. + * + * @returns Layout R such that composition(a.layout(), R) and composition(b.layout(), R) + * are both identity Layouts. + * + * Note that the returned layout does NOT include alignment concerns such as the pointer value and + * the divisbility of dynamic strides. + */ +template +CUTE_HOST_DEVICE constexpr +auto +max_common_layout(Tensor const& a, + Tensor const& b) +{ + using SrcType = typename SrcEngine::value_type; + using SrcRef = typename SrcEngine::reference; + using DstType = typename DstEngine::value_type; + using DstRef = typename DstEngine::reference; + + // Determine if vectorization candidates at all + if constexpr (// Should be the same value_types, else the copy is also performing a cast + cute::is_same::value && + // The types should be trivially copyable so that vectorization is valid + is_trivially_copyable::value && + is_trivially_copyable::value && + // Should be load/storing real data, rather than implicit iterators or such + is_reference::value && + is_reference::value) + { + return max_common_layout(a.layout(), b.layout()); + } else { + return Layout<_1,_0>{}; + } + + CUTE_GCC_UNREACHABLE; +} + +/* Return the maximum (statically known) alignment of a Tensor in the number of bits + */ +template +CUTE_HOST_DEVICE constexpr +auto +max_alignment(Tensor const& t) +{ + return gcd(max_alignment(t.data()), + max_alignment(t.layout()) * static_value>()); +} + +// +// Key algebraic operations -- Composition, Divide, and Product +// + +// Apply a Tiler to the Tensor via composition. +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +composition(Tensor && tensor, + Tiler const& tiler) // Layout or Tile or Shape +{ + return make_tensor(static_cast(tensor).data(), + composition(tensor.layout(), tiler)); +} + +// Apply a Tiler to the Tensor. +// +// Consider a Tensor with shape (A,B,x,y) +// And a Tiler that is: +// +// * A Layout with shape (BLK_A,BLK_B) +// ** Result Tensor shape ((BLK_A,BLK_B),Rest). +// ** That is, the Tensor and Tile are treated as 1D for the tiling. +// ** See logical_divide(Layout,Layout) +// +// * A Tile with shape +// ** Result Tensor shape ((BLK_A,a),(BLK_B,b),x,y). +// ** Each mode of the Tile is applied to the corresponding mode of the Tensor. +// ** See logical_divide(Layout,Tuple) +// +// * A Shape (BLK_A,BLK_B) +// ** Result Tensor shape ((BLK_A,a),(BLK_B,b),x,y). +// ** Equivalent to applying Tile. 
+// ** See logical_divide(Layout,Tuple) and logical_divide(Layout,Int) +// +// Note that the Tile/Shape Tilers must be weakly_congruent to the Tensor +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +logical_divide(Tensor && tensor, + Tiler const& tiler) // Layout or Tile or Shape +{ + return make_tensor(static_cast(tensor).data(), + logical_divide(tensor.layout(), tiler)); +} + +// zipped_divide is logical_divide with Tiler modes and Rest modes gathered together: (Tiler,Rest) +// When Tiler is Layout, this has no effect as logical_divide results in the same. +// When Tiler is Tile or Shape, this zips modes into standard form ((BLK_A,BLK_B),(a,b,x,y)) +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +zipped_divide(Tensor && tensor, + Tiler const& tiler) // Layout or Tile or Shape +{ + return make_tensor(static_cast(tensor).data(), + zipped_divide(tensor.layout(), tiler)); +} + +// tiled_divide is zipped_divide with the second output mode flattened ((BLK_A,BLK_B),a,b,x,y) +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +tiled_divide(Tensor && tensor, + Tiler const& tiler) // Layout or Tile or Shape +{ + return make_tensor(static_cast(tensor).data(), + tiled_divide(tensor.layout(), tiler)); +} + +// flat_divide is zipped_divide with the both modes flattened (BLK_A,BLK_B,a,b,x,y) +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +flat_divide(Tensor && tensor, + Tiler const& tiler) // Layout or Tile or Shape +{ + return make_tensor(static_cast(tensor).data(), + flat_divide(tensor.layout(), tiler)); +} + +// logical_product on a Tensor doesn't make sense since it often increases cosize +// though this might make sense for creating Tensors with broadcasted (stride-0) modes + +// +// Tensor partitioning utilities +// + +// Apply a Tiler to the Tensor, then slice out one of those tiles by slicing into the "Rest" modes. +// With an inner_partition, you get everything that's inside the Tiler. Everything that the Tiler is pointing to. +// Split the modes of tensor according to the Tiler +// zipped_divide returns something like ((BLK_A,BLK_B,...),(a,b,...,x,y)) +// Then slice into the second mode (the "Rest" mode) with Coord +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +inner_partition(Tensor && tensor, + Tiler const& tiler, + Coord const& coord) +{ + auto tensor_tiled = zipped_divide(static_cast(tensor), tiler); + constexpr int R0 = decltype(rank<0>(tensor_tiled))::value; + + // The coord slices into the second mode (the "rest" mode), flatten the first + if constexpr (is_tuple::value) { + // Append trailing modes if coord is tuple + constexpr int R1 = decltype(rank<1>(tensor_tiled))::value; + return tensor_tiled(repeat(_), append(coord,_)); + } else { + // Flat indexing if coord is not tuple + return tensor_tiled(repeat(_), coord); + } +} + +// Apply a Tiler to the Tensor, then slice out the remainder by slicing into the "Tile" modes. +// With an outer_partition, you get everything that's outside the Tiler. The layout of the Tile in the Tensor. 
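+// For example (illustrative):
+//   Tensor data = ...                                                        // ( M, N)
+//   Tensor rest = outer_partition(data, Shape<_32,_64>{}, make_coord(0,0));  // (M/32,N/64)
+// selects the (0,0) element of every 32x64 tile, i.e. the arrangement of the tiles themselves.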
+// Split the modes of tensor according to the Tiler +// zipped_divide returns something like ((BLK_A,BLK_B,...),(a,b,...,x,y)) +// Then slice into the first mode (the "Tile" mode) with Coord +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +outer_partition(Tensor && tensor, + Tiler const& tiler, + Coord const& coord) +{ + auto tensor_tiled = zipped_divide(static_cast(tensor), tiler); + constexpr int R1 = decltype(rank<1>(tensor_tiled))::value; + + // The coord slices into the first mode (the "tile" mode), flatten the second + if constexpr (is_tuple::value) { + // Append trailing modes if coord is tuple + constexpr int R0 = decltype(rank<0>(tensor_tiled))::value; + return tensor_tiled(append(coord,_), repeat(_)); + } else { + // Flat indexing if coord is not tuple + return tensor_tiled(coord, repeat(_)); + } +} + +// Tile a tensor according to @a tiler and use @a coord to index into the remainder, keeping the tile. +// This is typical at the CTA level where tiles of data are extracted: +// Tensor data = ... // ( M, N) +// Tensor cta_data = local_tile(data, Shape<_32,_64>{}, make_coord(blockIdx.x,blockIdx.y)); // (_32,_64) +template >::value)> +CUTE_HOST_DEVICE constexpr +auto +local_tile(Tensor && tensor, + Tiler const& tiler, // tiler to apply + Coord const& coord) // coord to slice into "remainder" +{ + return inner_partition(static_cast(tensor), + tiler, + coord); +} + +// Same as above, but with a projection parameter to strip out unwanted tiling modes for convenience +// when using projections of the same tiler. +// This is typical at the CTA level where tiles of data are extracted as projections: +// Tensor dataA = ... // (M,K) +// Tensor dataB = ... // (N,K) +// Tensor dataC = ... // (M,N) +// auto cta_tiler = Shape<_32, _64, _4>{}; +// auto cta_coord = make_coord(blockIdx.x, blockIdx.y, _); +// Tensor ctaA = local_tile(dataA, cta_tiler, cta_coord, Step<_1, X,_1>{}); // (_32,_4,k) +// Tensor ctaB = local_tile(dataB, cta_tiler, cta_coord, Step< X,_1,_1>{}); // (_64,_4,k) +// Tensor ctaC = local_tile(dataC, cta_tiler, cta_coord, Step<_1,_1, X>{}); // (_32,_64) +template >::value)> +CUTE_HOST_DEVICE +auto +local_tile(Tensor && tensor, + Tiler const& tiler, // tiler to apply + Coord const& coord, // coord to slice into "remainder" + Proj const& proj) // projection to apply to tiler and coord +{ + return local_tile(static_cast(tensor), + dice(proj, tiler), + dice(proj, coord)); +} + +// Tile a tensor according to the flat shape of a layout that provides the coordinate of the target index. +// This is typical at the Thread level where data is partitioned across repeated patterns of threads: +// Tensor data = ... // (_16,_64) +// Tensor thr_data = local_partition(data, Layout>{}, thr_idx); // ( _8, _4) +template >::value)> +CUTE_HOST_DEVICE +auto +local_partition(Tensor && tensor, + Layout const& tile, // coord -> index + Index const& index) // index to slice for +{ + static_assert(is_integral::value); + return outer_partition(static_cast(tensor), + product_each(shape(tile)), + tile.get_flat_coord(index)); +} + +// Same as above, but with a projection parameter to strip out unwanted tiling modes for convenience +// when using projections of the same tiler. +// This is typical at the Thread level where data is partitioned across projected layouts of threads: +// Tensor dataA = ... // (M,K) +// Tensor dataB = ... // (N,K) +// Tensor dataC = ... 
// (M,N) +// auto thr_layout = Layout, Stride<_16,_1,_0>>{}; +// Tensor thrA = local_partition(dataA, thr_layout, thr_idx, Step<_1, X,_1>{}); // (M/2,K/1) +// Tensor thrB = local_partition(dataB, thr_layout, thr_idx, Step< X,_1,_1>{}); // (N/16,K/1) +// Tensor thrC = local_partition(dataC, thr_layout, thr_idx, Step<_1,_1, X>{}); // (M/2,N/16) +template >::value)> +CUTE_HOST_DEVICE +auto +local_partition(Tensor && tensor, + Layout const& tile, // coord -> index + Index const& index, // index to slice for + Projection const& proj) +{ + return local_partition(static_cast(tensor), + dice(proj, tile), + index); +} + +// +// Display utilities +// + +template +CUTE_HOST_DEVICE void print(Tensor const& tensor) +{ + print(tensor.data()); print(" o "); print(tensor.layout()); +} + +template +CUTE_HOST_DEVICE void print_tensor(Tensor const& tensor, bool print_type = true) +{ + if (print_type) { + print(tensor); print(":\n"); + } + + if constexpr (Layout::rank == 1) + { + for (int m = 0; m < size(tensor); ++m) { + pretty_print(tensor(m)); + printf("\n"); + } + } else + if constexpr (Layout::rank == 2) + { + for (int m = 0; m < size<0>(tensor); ++m) { + for (int n = 0; n < size<1>(tensor); ++n) { + pretty_print(tensor(m,n)); + } + printf("\n"); + } + } else + if constexpr (Layout::rank == 3) + { + print_tensor(tensor(_,_,0), false); + for (int k = 1; k < size<2>(tensor); ++k) { + for (int i = 0; i < 5*size<1>(tensor); ++i) { print("-"); } print("\n"); + print_tensor(tensor(_,_,k), false); + } + } else + if constexpr (Layout::rank == 4) + { + print_tensor(tensor(_,_,_,0), false); + for (int p = 1; p < size<3>(tensor); ++p) { + for (int i = 0; i < 5*size<1>(tensor); ++i) { print("="); } print("\n"); + print_tensor(tensor(_,_,_,p), false); + } + } +} + +#if !defined(__CUDACC_RTC__) +template +CUTE_HOST std::ostream& print_tensor_os(std::ostream& os, Tensor const& tensor) +{ + int digits = 9; + + if constexpr (Layout::rank == 1) + { + for (int m = 0; m < size(tensor); ++m) { + os << std::setw(digits) << tensor(m) << std::endl; + } + } else + if constexpr (Layout::rank == 2) + { + for (int m = 0; m < size<0>(tensor); ++m) { + for (int n = 0; n < size<1>(tensor); ++n) { + os << std::setw(digits) << tensor(m,n); + } + os << std::endl; + } + } else + if constexpr (Layout::rank == 3) + { + print_tensor_os(os, tensor(_,_,0)); + for (int k = 1; k < size<2>(tensor); ++k) { + for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "-"; } os << std::endl; + print_tensor_os(os, tensor(_,_,k)); + } + } else + if constexpr (Layout::rank == 4) + { + print_tensor_os(os, tensor(_,_,_,0)); + for (int p = 1; p < size<3>(tensor); ++p) { + for (int i = 0; i < digits*size<1>(tensor); ++i) { os << "="; } os << std::endl; + print_tensor_os(os, tensor(_,_,_,p)); + } + } + + return os; +} + +template +CUTE_HOST std::ostream& operator<<(std::ostream& os, Tensor const& tensor) +{ + os << tensor.layout() << std::endl; + return print_tensor_os(os, tensor); +} +#endif // !defined(__CUDACC_RTC__) + +} // end namespace cute + diff --git a/include/cute/tensor_predicate.hpp b/include/cute/tensor_predicate.hpp new file mode 100644 index 0000000000..9c8a2ba614 --- /dev/null +++ b/include/cute/tensor_predicate.hpp @@ -0,0 +1,78 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
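As a small, hypothetical end-to-end sketch of the partitioning utilities above: a CTA-level local_tile followed by a thread-level local_partition, dumped with print_tensor. The buffer, the 32-thread layout, and the tile sizes are illustrative choices, not values taken from the patch.

#include <cute/tensor.hpp>
using namespace cute;

void partition_demo(int thr_idx)   // thr_idx in [0, 32)
{
  float buf[64 * 128];
  for (int i = 0; i < 64 * 128; ++i) { buf[i] = float(i); }
  Tensor data = make_tensor(&buf[0], make_shape(Int<64>{}, Int<128>{}));     // (_64,_128)

  // CTA level: keep the 32x64 tile at tile-coordinate (1,0)
  Tensor cta_data = local_tile(data, Shape<_32,_64>{}, make_coord(1, 0));    // (_32,_64)

  // Thread level: a 2x16 arrangement of 32 threads; each thread owns a (_16,_4) slice
  auto thr_layout = Layout<Shape<_2,_16>, Stride<_1,_2>>{};
  Tensor thr_data = local_partition(cta_data, thr_layout, thr_idx);          // (_16,_4)

  print_tensor(thr_data);   // dumps a 16x4 grid of this thread's values
}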
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::true_type + +namespace cute +{ + +template +struct ConstantTensor +{ + template + CUTE_HOST_DEVICE constexpr + T const& + operator()(Coords const&...) const { + return val_; + } + + T val_; +}; + +struct TrivialPredTensor +{ + template + CUTE_HOST_DEVICE constexpr + true_type + operator()(Coords const&...) const { + return {}; + } +}; + +template +struct FunctionPredTensor +{ + CUTE_HOST_DEVICE constexpr + FunctionPredTensor(Fn const& fn) : fn_(fn) {} + + template + CUTE_HOST_DEVICE constexpr + auto + operator()(Coords const&... coords) const { + return fn_(coords...); + } + + Fn const& fn_; +}; + +} // end namespace cute diff --git a/include/cute/tensor_zip.hpp b/include/cute/tensor_zip.hpp new file mode 100644 index 0000000000..6d70ffc847 --- /dev/null +++ b/include/cute/tensor_zip.hpp @@ -0,0 +1,243 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
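For illustration, a few hedged examples of the predicate-tensor helpers introduced above; these objects simply map any coordinate to a value, which is the shape of predicate that masking algorithms such as cute::copy_if expect. The function and lambda names are made up.

#include <cute/tensor_predicate.hpp>

void predicate_demo()
{
  // A tensor-like object that returns the same value at every coordinate.
  cute::ConstantTensor<float> ones{1.0f};
  float v = ones(3, 7);                        // 1.0f

  // Always-true predicate; the natural "no masking" default.
  cute::TrivialPredTensor no_mask{};
  bool p0 = no_mask(123);                      // true

  // Wrap an arbitrary callable as a predicate over coordinates.
  // NOTE: FunctionPredTensor stores a reference, so the callable must outlive it.
  auto in_bounds = [](int m, int n) { return m < 60 && n < 120; };
  cute::FunctionPredTensor pred(in_bounds);
  bool p1 = pred(59, 119);                     // true

  (void)v; (void)p0; (void)p1;
}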
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::Tensor +#include // cute::tuple + +namespace cute +{ + +// A tuple of Iterators that can be offset asymmetrically +// Note that this only accepts op+(tuple) and op[tuple] +// where each iterator will be offset by its respective index only. +// READ-ONLY for now until cute::tuple can be constructed with references. +template +struct ZipIterator +{ + using value_type = cute::tuple...>; + using element_type = cute::tuple...>; + // NOTE: cute::tuple does not support constructions with references at the moment. + // Consider fixes and/or an implementation of std::forward_as_tuple. + // For now, use a cute::tuple of value_types instead, which makes this Iterator READ-ONLY. + //using reference = cute::tuple...>; + using reference = value_type; + + ZipIterator() = delete; + + CUTE_HOST_DEVICE constexpr + ZipIterator(Iters... iters) + : iters_(iters...) + {} + + CUTE_HOST_DEVICE constexpr + ZipIterator(cute::tuple const& iters) + : iters_(iters) + {} + + CUTE_HOST_DEVICE constexpr + reference operator*() const { + return cute::apply(iters_, [](auto&&... args) { return reference(*args...); }); + } + + template + CUTE_HOST_DEVICE constexpr + ZipIterator operator+(cute::tuple const& idxs) const { + static_assert(sizeof...(Index) == sizeof...(Iters), "Expect same number of offsets as iterators."); + return cute::transform(iters_, idxs, [](auto&& iter, auto&& idx) { return iter + idx; }); + } + + template + CUTE_HOST_DEVICE constexpr + reference operator[](cute::tuple const& idxs) const { + return *(*this + idxs); + } + + cute::tuple iters_; +}; + +//------------------------------------------------------------------------------ +// type traits + +template +struct is_rmem> : conjunction...> {}; +template +struct is_smem> : conjunction...> {}; +template +struct is_gmem> : conjunction...> {}; +// A tuple of Layouts that operates on each Layout symmetrically +// The Layouts need to have compatible shapes and ranks. +// The ZipLayout presents the intersection of the domain of its component Layouts. +// E.g. all Layouts accept 1D coords and ZipLayout does as well. +// The ZipLayout returns the union of the codomain of its component Layouts. +// E.g. all Layouts return an integer so ZipLayout returns a tuple of integers. +template +struct ZipLayout +{ + static constexpr int rank = (int(0) | ... | Layouts::rank); + + static_assert((is_layout::value && ...), "All template parameters must be layouts"); + static_assert(((Layouts::rank == rank) && ...), "All layouts must have the same rank"); + + CUTE_HOST_DEVICE constexpr + ZipLayout(Layouts const&... 
layouts) + : layouts_(layouts...) + {} + + CUTE_HOST_DEVICE constexpr + ZipLayout(cute::tuple const& layouts) + : layouts_(layouts) + {} + + template + CUTE_HOST_DEVICE constexpr + auto + operator()(Coord const& coord) const { + if constexpr (has_underscore::value) { + return ZipLayout(cute::transform(layouts_, [&] (auto layout) { return layout(coord); })); + } else { + return cute::transform(layouts_, [&] (auto layout) { return layout(coord); }); + } + + CUTE_GCC_UNREACHABLE; + } + + // op() convenience function for multi-dimensional coordinates + template + CUTE_HOST_DEVICE constexpr + decltype(auto) + operator()(Coord0 const& c0, Coord1 const& c1, Coords const&... cs) const { + return operator()(make_coord(c0,c1,cs...)); + } + + cute::tuple layouts_; +}; + +template +struct is_layout> : true_type {}; + +// +// make_zip_tensor and unzip_tensor +// + +template +CUTE_HOST_DEVICE constexpr +auto +make_zip_tensor(Tensor const&... tensors) +{ + return make_tensor(ZipIterator(tensors.data()...), + ZipLayout(tensors.layout()...)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +unzip_tensor(Tensor const& tensor) +{ + return cute::transform(tensor.data().iters_, tensor.layout().layouts_, + [](auto iter, auto layout) { return make_tensor(iter, layout); }); +} + +// +// Utilities +// + +template +CUTE_HOST_DEVICE constexpr +auto +rank(ZipLayout const& layouts) +{ + return rank(get<0>(layouts.layouts_)); +} + +template +CUTE_HOST_DEVICE constexpr +auto +size(ZipLayout const& layouts) +{ + return size(get<0>(layouts.layouts_)); +} + +// +// Manipulation +// + +// Extend each component layout to rank-N by appending Layout @a x. +template +CUTE_HOST_DEVICE constexpr +auto +append(ZipLayout const& layouts, + Layout const& x = {}) +{ + return ZipLayout(cute::transform(layouts.layouts_, [&](auto t){ return append(t, x); })); +} + +// Extend each component layout to rank-N by prepending Layout @a x. +template +CUTE_HOST_DEVICE constexpr +auto +prepend(ZipLayout const& layouts, + Layout const& x = {}) +{ + return ZipLayout(cute::transform(layouts.layouts_, [&](auto t){ return prepend(t, x); })); +} + +template +CUTE_HOST_DEVICE constexpr +auto +logical_divide(ZipLayout const& layouts, + Tiler const& tiler) +{ + return ZipLayout(cute::transform(layouts.layouts_, [&](auto t){ return logical_divide(t, tiler); })); +} + +template +CUTE_HOST_DEVICE constexpr +auto +zipped_divide(ZipLayout const& layouts, + Tiler const& tiler) +{ + return ZipLayout(cute::transform(layouts.layouts_, [&](auto t){ return zipped_divide(t, tiler); })); +} + +// Return by calling slice_and_offset and all component layouts. +template +CUTE_HOST_DEVICE constexpr +auto +slice_and_offset(Coord const& c, ZipLayout const& layouts) +{ + auto result = cute::zip(cute::transform(layouts.layouts_, [&c](auto const& layout) { return slice_and_offset(c, layout); })); + return cute::make_tuple(ZipLayout(get<0>(result)), get<1>(result)); +} + +} // end namespace cute diff --git a/include/cute/underscore.hpp b/include/cute/underscore.hpp new file mode 100644 index 0000000000..e9d80fe5b5 --- /dev/null +++ b/include/cute/underscore.hpp @@ -0,0 +1,194 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_INLINE_CONSTANT, CUTE_HOST_DEVICE +#include // cute::is_tuple +#include // cute::false_type, cute::true_type + +namespace cute +{ + +// For slicing +struct Underscore : Int<0> {}; + +CUTE_INLINE_CONSTANT Underscore _; + +// Convenient alias +using X = Underscore; + +// Treat Underscore as an integral like integral_constant +template <> +struct is_integral : true_type {}; + +template +struct is_underscore : false_type {}; +template <> +struct is_underscore : true_type {}; + +// Tuple trait for detecting static member element +template +struct has_elem : false_type {}; +template +struct has_elem : true_type {}; +template +struct has_elem::value> > + : has_elem > {}; +template +struct has_elem> + : disjunction, Elem>...> {}; + +// Tuple trait for detecting static member element +template +struct all_elem : false_type {}; +template +struct all_elem : true_type {}; +template +struct all_elem::value> > + : all_elem > {}; +template +struct all_elem> + : conjunction, Elem>...> {}; + +// Tuple trait for detecting Underscore member +template +using has_underscore = has_elem; + +template +using all_underscore = all_elem; + +template +using has_int1 = has_elem>; + +template +using has_int0 = has_elem>; + +// +// Slice keeps only the elements of Tuple B that are paired with an Underscore +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +lift_slice(A const& a, B const& b) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return filter_tuple(a, b, [](auto const& x, auto const& y) { return lift_slice(x,y); }); + } else if constexpr (is_underscore::value) { + return cute::tuple{b}; + } else { + return cute::tuple<>{}; + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +// Entry point overrides the lifting so that slice(_,b) == b +template +CUTE_HOST_DEVICE constexpr +auto +slice(A const& a, B const& b) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return 
filter_tuple(a, b, [](auto const& x, auto const& y) { return detail::lift_slice(x,y); }); + } else if constexpr (is_underscore::value) { + return b; + } else { + return cute::tuple<>{}; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Dice keeps only the elements of Tuple B that are paired with an Int +// + +namespace detail { + +template +CUTE_HOST_DEVICE constexpr +auto +lift_dice(A const& a, B const& b) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return filter_tuple(a, b, [](auto const& x, auto const& y) { return lift_dice(x,y); }); + } else if constexpr (is_underscore::value) { + return cute::tuple<>{}; + } else { + return cute::tuple{b}; + } + + CUTE_GCC_UNREACHABLE; +} + +} // end namespace detail + +// Entry point overrides the lifting so that dice(1,b) == b +template +CUTE_HOST_DEVICE constexpr +auto +dice(A const& a, B const& b) +{ + if constexpr (is_tuple::value) { + static_assert(tuple_size::value == tuple_size::value, "Mismatched Ranks"); + return filter_tuple(a, b, [](auto const& x, auto const& y) { return detail::lift_dice(x,y); }); + } else if constexpr (is_underscore::value) { + return cute::tuple<>{}; + } else { + return b; + } + + CUTE_GCC_UNREACHABLE; +} + +// +// Display utilities +// + +CUTE_HOST_DEVICE void print(Underscore const&) { + printf("_"); +} + +#if !defined(__CUDACC_RTC__) +CUTE_HOST std::ostream& operator<<(std::ostream& os, Underscore const&) { + return os << "_"; +} +#endif + +} // end namespace cute diff --git a/include/cute/util/debug.hpp b/include/cute/util/debug.hpp new file mode 100644 index 0000000000..2645444369 --- /dev/null +++ b/include/cute/util/debug.hpp @@ -0,0 +1,164 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
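A short sketch of the slice/dice semantics defined above, using hypothetical coordinates: slice keeps the entries of b that are paired with an Underscore in a, while dice keeps the entries paired with anything else.

#include <cute/tensor.hpp>
using namespace cute;

void slice_dice_demo()
{
  auto profile = make_coord(_, 3, _);     // Underscore marks the modes kept by slice
  auto coord   = make_coord(10, 20, 30);

  auto s = slice(profile, coord);         // (10,30)    : entries paired with an Underscore
  auto d = dice (profile, coord);         // (20)       : entries paired with a non-Underscore
  auto a = slice(_, coord);               // (10,20,30) : slice(_, b) == b

  (void)s; (void)d; (void)a;
}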
+ * + **************************************************************************************************/ +#pragma once + +/** + * \file + * \brief Debugging and logging functionality + */ + +#include + +#include + +namespace cute +{ + +/****************************************************************************** + * Debug and logging macros + ******************************************************************************/ + +/** + * Formats and prints the given message to stdout + */ +#if !defined(CUTE_LOG) +# if !defined(__CUDA_ARCH__) +# define CUTE_LOG(format, ...) printf(format, __VA_ARGS__) +# else +# define CUTE_LOG(format, ...) \ + printf("[block (%d,%d,%d), thread (%d,%d,%d)]: " format, \ + blockIdx.x, blockIdx.y, blockIdx.z, \ + threadIdx.x, threadIdx.y, threadIdx.z, \ + __VA_ARGS__); +# endif +#endif + +/** + * Formats and prints the given message to stdout only if DEBUG is defined + */ +#if !defined(CUTE_LOG_DEBUG) +# ifdef DEBUG +# define CUTE_LOG_DEBUG(format, ...) CUTE_LOG(format, __VA_ARGS__) +# else +# define CUTE_LOG_DEBUG(format, ...) +# endif +#endif + +/** + * \brief Perror macro with exit + */ +#if !defined(CUTE_ERROR_EXIT) +# define CUTE_ERROR_EXIT(e) \ + do { \ + cudaError_t code = (e); \ + if (code != cudaSuccess) { \ + fprintf(stderr, "<%s:%d> %s:\n %s: %s\n", \ + __FILE__, __LINE__, #e, \ + cudaGetErrorName(code), cudaGetErrorString(code)); \ + fflush(stderr); \ + exit(1); \ + } \ + } while (0) +#endif + +#if !defined(CUTE_CHECK_LAST) +# define CUTE_CHECK_LAST() CUTE_ERROR_EXIT(cudaPeekAtLastError()); CUTE_ERROR_EXIT(cudaDeviceSynchronize()) +#endif + +#if !defined(CUTE_CHECK_ERROR) +# define CUTE_CHECK_ERROR(e) CUTE_ERROR_EXIT(e) +#endif + +// A dummy function that uses compilation failure to print a type +template +CUTE_HOST_DEVICE void +print_type() { + static_assert(sizeof...(T) < 0, "Printing type T."); +} + +template +CUTE_HOST_DEVICE void +print_type(T&&...) { + static_assert(sizeof...(T) < 0, "Printing type T."); +} + +// +// Device-specific helpers +// +// e.g. +// if (thread0()) print(...); +// if (block0()) print(...); +// if (thread(42)) print(...); + +CUTE_HOST_DEVICE +bool +block([[maybe_unused]] int bid) +{ +#if defined(__CUDA_ARCH__) + return blockIdx.x + blockIdx.y*gridDim.x + blockIdx.z*gridDim.x*gridDim.y == static_cast(bid); +#else + return true; +#endif +} + +CUTE_HOST_DEVICE +bool +thread([[maybe_unused]] int tid, [[maybe_unused]] int bid) +{ +#if defined(__CUDA_ARCH__) + return (threadIdx.x + threadIdx.y*blockDim.x + threadIdx.z*blockDim.x*blockDim.y == static_cast(tid)) && block(bid); +#else + return true; +#endif +} + +CUTE_HOST_DEVICE +bool +thread(int tid) +{ + return thread(tid,0); +} + +CUTE_HOST_DEVICE +bool +thread0() +{ + return thread(0,0); +} + +CUTE_HOST_DEVICE +bool +block0() +{ + return block(0); +} + +} // end namespace cute diff --git a/include/cute/util/print.hpp b/include/cute/util/print.hpp new file mode 100644 index 0000000000..dbd6581693 --- /dev/null +++ b/include/cute/util/print.hpp @@ -0,0 +1,261 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. 
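A minimal CUDA sketch (the kernel and launch configuration are hypothetical) of the debug helpers above: thread/block gating for prints, CUTE_LOG, and the host-side error-checking macro. Compile as a .cu file.

#include <cute/util/debug.hpp>
#include <cute/util/print.hpp>

__global__ void debug_demo_kernel()
{
  if (cute::thread0()) {                          // thread 0 of block 0 only
    cute::print("grid=(%u,%u,%u)\n", gridDim.x, gridDim.y, gridDim.z);
  }
  if (cute::thread(42, 1)) {                      // thread 42 of block 1 only
    CUTE_LOG("hello from linear tid %d\n", 42);   // device builds prefix [block (...), thread (...)]
  }
}

int main()
{
  debug_demo_kernel<<<2, 64>>>();
  CUTE_CHECK_LAST();   // cudaPeekAtLastError + cudaDeviceSynchronize, exit(1) on failure
  return 0;
}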
+ * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include // CUTE_HOST_DEVICE +#include // cute::is_valid +#include + +// +// CUDA compatible print and printf +// + +namespace cute +{ + +CUTE_HOST_DEVICE +int +num_digits(int x) +{ + return (x < 10 ? 1 : + (x < 100 ? 2 : + (x < 1000 ? 3 : + (x < 10000 ? 4 : + (x < 100000 ? 5 : + (x < 1000000 ? 6 : + (x < 10000000 ? 7 : + (x < 100000000 ? 8 : + (x < 1000000000 ? 9 : + 10))))))))); +} + +// +// print dispatcher +// + +CUTE_HOST_DEVICE +void +print(char c) { + printf("%c", c); +} + +CUTE_HOST_DEVICE +void +print(signed char a) { + printf("%d", static_cast(a)); +} + +CUTE_HOST_DEVICE +void +print(unsigned char a) { + printf("%u", static_cast(a)); +} + +CUTE_HOST_DEVICE +void +print(short a) { + printf("%hd", a); +} + +CUTE_HOST_DEVICE +void +print(unsigned short a) { + printf("%hu", a); +} + +CUTE_HOST_DEVICE +void +print(int a) { + printf("%d", a); +} + +CUTE_HOST_DEVICE +void +print(uint1b_t a) { + printf("%d", int(a)); +} + +CUTE_HOST_DEVICE +void +print(int2b_t a) { + printf("%d", int(a)); +} + +CUTE_HOST_DEVICE +void +print(uint2b_t a) { + printf("%d", int(a)); +} + +CUTE_HOST_DEVICE +void +print(int4b_t a) { + printf("%d", int(a)); +} + +CUTE_HOST_DEVICE +void +print(uint4b_t a) { + printf("%d", int(a)); +} + +CUTE_HOST_DEVICE +void +print(bin1_t a) { + printf("%d", int(a)); +} + +CUTE_HOST_DEVICE +void +print(unsigned int a) { + printf("%u", a); +} + +CUTE_HOST_DEVICE +void +print(long a) { + printf("%ld", a); +} + +CUTE_HOST_DEVICE +void +print(unsigned long a) { + printf("%lu", a); +} + +CUTE_HOST_DEVICE +void +print(long long a) { + printf("%lld", a); +} + +CUTE_HOST_DEVICE +void +print(unsigned long long a) { + printf("%llu", a); +} + +CUTE_HOST_DEVICE +void +print(float a) { + printf("%f", a); +} + +CUTE_HOST_DEVICE +void +print(double a) { + printf("%f", a); +} + +template +CUTE_HOST_DEVICE +void +print(char const* format, T const&... 
t) { + printf(format, t...); +} + +CUTE_HOST_DEVICE +void +print(char const* format) { + printf("%s", format); +} + +// +// pretty printing +// + +CUTE_HOST_DEVICE void +pretty_print(uint1b_t a) { + printf("%*d", 3, int(a)); +} + +CUTE_HOST_DEVICE void +pretty_print(int2b_t a) { + printf("%*d", 5, int(a)); +} + +CUTE_HOST_DEVICE void +pretty_print(uint2b_t a) { + printf("%*d", 5, int(a)); +} + +CUTE_HOST_DEVICE void +pretty_print(int4b_t a) { + printf("%*d", 5, int(a)); +} + +CUTE_HOST_DEVICE void +pretty_print(uint4b_t a) { + printf("%*d", 5, int(a)); +} + +CUTE_HOST_DEVICE void +pretty_print(bool v) { + printf("%*d", 3, int(v)); +} + +CUTE_HOST_DEVICE void +pretty_print(int32_t v) { + printf("%*d", 5, v); +} + +CUTE_HOST_DEVICE void +pretty_print(uint32_t v) { + printf("%*d", 5, v); +} + +CUTE_HOST_DEVICE void +pretty_print(int64_t v) { + printf("%*lld", 5, static_cast(v)); +} + +CUTE_HOST_DEVICE void +pretty_print(uint64_t v) { + printf("%*llu", 5, static_cast(v)); +} + +CUTE_HOST_DEVICE void +pretty_print(float v) { + printf("%*.2e", 10, v); +} + +CUTE_HOST_DEVICE void +pretty_print(double v) { + printf("%*.3e", 11, v); +} + +template +CUTE_HOST_DEVICE void +pretty_print(T t) { + printf(" "); print(t); +} + +} // end namespace cute diff --git a/include/cute/util/type_traits.hpp b/include/cute/util/type_traits.hpp new file mode 100644 index 0000000000..a3074ef947 --- /dev/null +++ b/include/cute/util/type_traits.hpp @@ -0,0 +1,298 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
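For reference, a small hedged sketch of the difference between the two dispatchers above: print uses the natural format of the value, while pretty_print pads to a fixed field width so that print_tensor grids line up.

#include <cute/util/print.hpp>

void print_demo()
{
  cute::print(-3);               // "-3"
  cute::print('\n');
  cute::pretty_print(-3);        // "   -3"        (width-5 field for 32-bit integers)
  cute::print('\n');
  cute::pretty_print(3.14159f);  // "  3.14e+00"   (width-10 scientific for float)
  cute::print('\n');
}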
+ * + **************************************************************************************************/ +#pragma once + +#if defined(__CUDACC_RTC__) +#include +#include +#include +#include +#include +#else +#include +#include // tuple_size, tuple_element +#include // ptrdiff_t +#include // uintptr_t +#include // numeric_limits +#endif + +#include // CUTE_STL_NAMESPACE + +namespace cute +{ + using CUTE_STL_NAMESPACE::enable_if; + using CUTE_STL_NAMESPACE::enable_if_t; +} + +#define __CUTE_REQUIRES(...) typename cute::enable_if<(__VA_ARGS__)>::type* = nullptr +#define __CUTE_REQUIRES_V(...) typename cute::enable_if::type* = nullptr + +namespace cute +{ + +// +using CUTE_STL_NAMESPACE::conjunction; +using CUTE_STL_NAMESPACE::conjunction_v; + +using CUTE_STL_NAMESPACE::disjunction; +using CUTE_STL_NAMESPACE::disjunction_v; + +using CUTE_STL_NAMESPACE::negation; +using CUTE_STL_NAMESPACE::negation_v; + +using CUTE_STL_NAMESPACE::void_t; +using CUTE_STL_NAMESPACE::is_void_v; + +using CUTE_STL_NAMESPACE::is_base_of; +using CUTE_STL_NAMESPACE::is_base_of_v; + +using CUTE_STL_NAMESPACE::is_const; +using CUTE_STL_NAMESPACE::is_const_v; +using CUTE_STL_NAMESPACE::is_volatile; +using CUTE_STL_NAMESPACE::is_volatile_v; + +// Defined in cute/numeric/integral_constant.hpp +// using CUTE_STL_NAMESPACE::true_type; +// using CUTE_STL_NAMESPACE::false_type; + +using CUTE_STL_NAMESPACE::conditional; +using CUTE_STL_NAMESPACE::conditional_t; + +using CUTE_STL_NAMESPACE::add_const_t; + +using CUTE_STL_NAMESPACE::remove_const_t; +using CUTE_STL_NAMESPACE::remove_cv_t; +using CUTE_STL_NAMESPACE::remove_reference_t; + +using CUTE_STL_NAMESPACE::extent; +using CUTE_STL_NAMESPACE::remove_extent; + +using CUTE_STL_NAMESPACE::decay; +using CUTE_STL_NAMESPACE::decay_t; + +using CUTE_STL_NAMESPACE::is_lvalue_reference; +using CUTE_STL_NAMESPACE::is_lvalue_reference_v; + +using CUTE_STL_NAMESPACE::is_reference; +using CUTE_STL_NAMESPACE::is_trivially_copyable; + +using CUTE_STL_NAMESPACE::is_convertible; +using CUTE_STL_NAMESPACE::is_convertible_v; + +using CUTE_STL_NAMESPACE::is_same; +using CUTE_STL_NAMESPACE::is_same_v; + +using CUTE_STL_NAMESPACE::is_constructible; +using CUTE_STL_NAMESPACE::is_constructible_v; +using CUTE_STL_NAMESPACE::is_default_constructible; +using CUTE_STL_NAMESPACE::is_default_constructible_v; +using CUTE_STL_NAMESPACE::is_standard_layout; +using CUTE_STL_NAMESPACE::is_standard_layout_v; + +using CUTE_STL_NAMESPACE::is_arithmetic; +using CUTE_STL_NAMESPACE::is_unsigned; +using CUTE_STL_NAMESPACE::is_unsigned_v; +using CUTE_STL_NAMESPACE::is_signed; +using CUTE_STL_NAMESPACE::is_signed_v; + +using CUTE_STL_NAMESPACE::make_signed; +using CUTE_STL_NAMESPACE::make_signed_t; + +// using CUTE_STL_NAMESPACE::is_integral; +template +using is_std_integral = CUTE_STL_NAMESPACE::is_integral; + +using CUTE_STL_NAMESPACE::is_empty; +using CUTE_STL_NAMESPACE::is_empty_v; + +using CUTE_STL_NAMESPACE::invoke_result_t; + +using CUTE_STL_NAMESPACE::common_type; +using CUTE_STL_NAMESPACE::common_type_t; + +using CUTE_STL_NAMESPACE::remove_pointer; +using CUTE_STL_NAMESPACE::remove_pointer_t; + +using CUTE_STL_NAMESPACE::add_pointer; +using CUTE_STL_NAMESPACE::add_pointer_t; + +using CUTE_STL_NAMESPACE::alignment_of; +using CUTE_STL_NAMESPACE::alignment_of_v; + +using CUTE_STL_NAMESPACE::is_pointer; +using CUTE_STL_NAMESPACE::is_pointer_v; + +// +using CUTE_STL_NAMESPACE::declval; + +template +constexpr T&& forward(remove_reference_t& t) noexcept +{ + return static_cast(t); +} + +template +constexpr T&& 
forward(remove_reference_t&& t) noexcept +{ + static_assert(! is_lvalue_reference_v, "T cannot be an lvalue reference (e.g., U&)."); + return static_cast(t); +} + +template +constexpr remove_reference_t&& move(T&& t) noexcept +{ + return static_cast&&>(t); +} + +// +using CUTE_STL_NAMESPACE::numeric_limits; + +// +using CUTE_STL_NAMESPACE::ptrdiff_t; + +// +using CUTE_STL_NAMESPACE::uintptr_t; + +// C++20 +// using std::remove_cvref; +template +struct remove_cvref { + using type = remove_cv_t>; +}; + +// C++20 +// using std::remove_cvref_t; +template +using remove_cvref_t = typename remove_cvref::type; + +// +// dependent_false +// +// @brief An always-false value that depends on one or more template parameters. +// See +// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2019/p1830r1.pdf +// https://github.com/cplusplus/papers/issues/572 +// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p2593r0.html +template +inline constexpr bool dependent_false = false; + +// +// tuple_size, tuple_element +// +// @brief CuTe-local tuple-traits to prevent conflicts with other libraries. +// For cute:: types, we specialize std::tuple-traits, which is explicitly allowed. +// cute::tuple, cute::array, cute::array_subbyte, etc +// But CuTe wants to treat some external types as tuples as well. For those, +// we specialize cute::tuple-traits to avoid polluting external traits. +// dim3, uint3, etc + +template +struct tuple_size; + +template +struct tuple_size::type>> : CUTE_STL_NAMESPACE::integral_constant::value> {}; + +// S = : std::integral_constant::value> {}; + +template +constexpr size_t tuple_size_v = tuple_size::value; + +template +struct tuple_element; + +template +struct tuple_element::type>> : CUTE_STL_NAMESPACE::tuple_element {}; + +template +using tuple_element_t = typename tuple_element::type; + +// +// is_valid +// + +namespace detail { + +template ()(declval()...))> +CUTE_HOST_DEVICE constexpr auto +is_valid_impl(int) { return CUTE_STL_NAMESPACE::true_type{}; } + +template +CUTE_HOST_DEVICE constexpr auto +is_valid_impl(...) { return CUTE_STL_NAMESPACE::false_type{}; } + +template +struct is_valid_fn { + template + CUTE_HOST_DEVICE constexpr auto + operator()(Args&&...) const { return is_valid_impl(int{}); } +}; + +} // end namespace detail + +template +CUTE_HOST_DEVICE constexpr auto +is_valid(F&&) { + return detail::is_valid_fn{}; +} + +template +CUTE_HOST_DEVICE constexpr auto +is_valid(F&&, Args&&...) { + return detail::is_valid_impl(int{}); +} + +template class True, template class False> +struct conditional_template { + template + using type = True; +}; + +template class True, template class False> +struct conditional_template { + template + using type = False; +}; + +// +// is_any_of +// + +// Member `value` is true if and only if T is same as (is_same_v) at least one of the types in Us +template +struct is_any_of { + constexpr static bool value = (... || CUTE_STL_NAMESPACE::is_same_v); +}; + +// Is true if and only if T is same as (is_same_v) at least one of the types in Us +template +inline constexpr bool is_any_of_v = is_any_of::value; + +} // end namespace cute diff --git a/include/cutlass/aligned_buffer.h b/include/cutlass/aligned_buffer.h index f869d388b0..0d2bb29048 100644 --- a/include/cutlass/aligned_buffer.h +++ b/include/cutlass/aligned_buffer.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
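As a hedged illustration of the is_valid detection helper defined above (the Foo/Bar types and has_size detector are made up): wrapping a generic lambda yields a callable whose result type is true_type or false_type depending on whether the expression in the lambda's trailing return type is well-formed for the given argument.

#include <cute/util/type_traits.hpp>

struct Foo { int size() const { return 4; } };
struct Bar {};

// Detector for "t.size() is a valid expression".
inline auto has_size = cute::is_valid([](auto&& t) -> decltype(t.size()) { return t.size(); });

static_assert( decltype(has_size(Foo{}))::value, "Foo provides .size()");
static_assert(!decltype(has_size(Bar{}))::value, "Bar does not provide .size()");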
+ * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/include/cutlass/arch/arch.h b/include/cutlass/arch/arch.h index 578e6c14a3..36d4676bdf 100644 --- a/include/cutlass/arch/arch.h +++ b/include/cutlass/arch/arch.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -34,12 +34,14 @@ #pragma once +#include "cutlass/cutlass.h" + //////////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { namespace arch { -#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) +#if defined(__NVCC__) || defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__)) /// Computes laneId within a warp CUTLASS_DEVICE @@ -84,6 +86,12 @@ struct Sm80 { struct Sm86 { static int const kMinComputeCapability = 86; }; +struct Sm89 { + static int const kMinComputeCapability = 89; +}; +struct Sm90 { + static int const kMinComputeCapability = 90; +}; /// Triggers a breakpoint on the device CUTLASS_DEVICE diff --git a/include/cutlass/arch/barrier.h b/include/cutlass/arch/barrier.h new file mode 100644 index 0000000000..460531aa89 --- /dev/null +++ b/include/cutlass/arch/barrier.h @@ -0,0 +1,723 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file + \brief Barrier Operations on SM90+ +*/ + +#pragma once + +#include +#include +#if defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && (__CUDACC_VER_MAJOR__ >= 12) +#define CUDA_BARRIER_ENABLED 1 +#else +#define CUDA_BARRIER_ENABLED 0 +#endif + +namespace cutlass { +/// @brief +namespace arch { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +CUTLASS_DEVICE void fence_view_async_shared(); + +namespace detail { // namespace detail begin + +// Single threaded versions that need to be called in an elect_one region +template +CUTLASS_DEVICE +void initialize_barrier_array(T ptr, int arv_cnt) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Stages; i++) { + ptr[i].init(arv_cnt); + } +} + +template +CUTLASS_DEVICE +void initialize_barrier_array(uint64_t *ptr, int arv_cnt) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Stages; i++) { + T::init(&ptr[i], arv_cnt); + } +} + +template +CUTLASS_DEVICE +void initialize_barrier_array_pair(FullBarrier full_barriers, EmptyBarrier empty_barriers, int full_barrier_arv_cnt, int empty_barrier_arv_cnt) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Stages; i++) { + full_barriers[i].init(full_barrier_arv_cnt); + empty_barriers[i].init(empty_barrier_arv_cnt); + } +} + +template +CUTLASS_DEVICE +void initialize_barrier_array_pair(uint64_t *full_barriers_ptr, uint64_t *empty_barriers_ptr, int full_barrier_arv_cnt, int empty_barrier_arv_cnt) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Stages; i++) { + FullBarrier::init(&full_barriers_ptr[i], full_barrier_arv_cnt); + EmptyBarrier::init(&empty_barriers_ptr[i], empty_barrier_arv_cnt); + } +} + +// Aligned versions that need to be call warp wide +template +CUTLASS_DEVICE +void initialize_barrier_array_aligned(T ptr, int arv_cnt) { + if(cute::elect_one_sync()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Stages; i++) { + ptr[i].init(arv_cnt); + } + } +} + +template +CUTLASS_DEVICE +void initialize_barrier_array_aligned(uint64_t *ptr, int arv_cnt) { + if(cute::elect_one_sync()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Stages; i++) { + T::init(&ptr[i], arv_cnt); + } + } +} + +template +CUTLASS_DEVICE +void initialize_barrier_array_pair_aligned(FullBarrier full_barriers, EmptyBarrier empty_barriers, int full_barrier_arv_cnt, int empty_barrier_arv_cnt) { + if(cute::elect_one_sync()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Stages; i++) { + full_barriers[i].init(full_barrier_arv_cnt); + empty_barriers[i].init(empty_barrier_arv_cnt); + } + } +} + +template +CUTLASS_DEVICE +void initialize_barrier_array_pair_aligned(uint64_t *full_barriers_ptr, uint64_t *empty_barriers_ptr, int full_barrier_arv_cnt, int empty_barrier_arv_cnt) { + if(cute::elect_one_sync()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Stages; i++) { + FullBarrier::init(&full_barriers_ptr[i], full_barrier_arv_cnt); + EmptyBarrier::init(&empty_barriers_ptr[i], empty_barrier_arv_cnt); + } + } +} + +} // namespace detail end + + +// Enumerates the reserved named barriers to avoid potential conflicts +// This enum class specifies the NamedBarriers reserved by CUTLASS. 
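A hedged device-side sketch of what the single-threaded initialization helpers above expand to; the stage count, arrive counts, and kernel are illustrative. One elected thread initializes an array of shared-memory mbarriers, then the CTA synchronizes before anyone uses them (for cluster-scope use, the barrier-init fence described at the end of this file should also be issued).

#include <cutlass/arch/barrier.h>

constexpr int Stages = 4;

__global__ void init_barriers_demo()
{
  __shared__ uint64_t full_barrier[Stages];
  __shared__ uint64_t empty_barrier[Stages];

  if (threadIdx.x == 0) {
    // Hand-rolled equivalent of the detail::initialize_barrier_array_pair helpers above.
    CUTLASS_PRAGMA_UNROLL
    for (int i = 0; i < Stages; ++i) {
      cutlass::arch::ClusterBarrier::init(&full_barrier[i],  /*arrive_count=*/ 1);
      cutlass::arch::ClusterBarrier::init(&empty_barrier[i], /*arrive_count=*/ blockDim.x);
    }
  }
  __syncthreads();   // make the initialized barriers visible to the whole CTA before first use
}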
+enum class ReservedNamedBarriers { + EpilogueBarrier = 1, + TransposeBarrier = 2, + TransformBarrier = 3, + StreamkBarrier0 = 4, + StreamkBarrier1 = 5 + , FirstUserBarrier = StreamkBarrier1 + 1 +}; + + +class NamedBarrier { + + // Data Members: + + // Range = [1 , NUM_THREADS_PER_CTA] + // Range % warp-size (i.e 32) == 0 + uint32_t const num_threads_; + + // Range : [0, 15] + // Note that should be set to the final barrier ID, including ReserveNamedBarrierCount should be considered + uint32_t const id_; + + public: + + // Constructor for CUTLASS developers: + // effective barrier ID starts from 0 + CUTLASS_DEVICE + NamedBarrier(uint32_t num_threads, ReservedNamedBarriers reserved_named_barriers) + : num_threads_(num_threads), id_(static_cast(reserved_named_barriers)) {} + + // Constructor for CUTLASS users: + // effective barrier ID starts from ReservedNamedBarrierCount + CUTLASS_DEVICE + NamedBarrier(uint32_t num_threads, uint32_t id = 0) + : num_threads_(num_threads), id_(id + ReservedNamedBarrierCount) { + CUTLASS_ASSERT(id + ReservedNamedBarrierCount <= HardwareMaxNumNamedBarriers && "Effective barrier_id should not exceed 16."); + } + + CUTLASS_DEVICE + void arrive_and_wait() const { + // Note: The value of id_ is already the final barrier id (set correctly in the constructor). + NamedBarrier::arrive_and_wait_internal(num_threads_, id_); + } + + CUTLASS_DEVICE + void arrive_and_wait_unaligned() const { + // Note: The value of id_ is already the final barrier id (set correctly in the constructor). + NamedBarrier::arrive_and_wait_internal_unaligned(num_threads_, id_); + } + + CUTLASS_DEVICE + void arrive() const { + // Note: The value of id_ is already the final barrier id (set correctly in the constructor). + NamedBarrier::arrive_internal(num_threads_, id_); + } + + CUTLASS_DEVICE + void arrive_unaligned() const { + // Note: The value of id_ is already the final barrier id (set correctly in the constructor). 
+ NamedBarrier::arrive_internal_unaligned(num_threads_, id_); + } + + CUTLASS_DEVICE + void sync() const { + NamedBarrier::arrive_and_wait(); + } + + // Static variants + + // Calling interface for CUTLASS users: + // effective barrier ID starts from ReservedNamedBarrierCount + CUTLASS_DEVICE + static void arrive_and_wait(uint32_t num_threads, uint32_t barrier_id) { + arrive_and_wait_internal(num_threads, barrier_id + ReservedNamedBarrierCount); + } + + // Calling interface for CUTLASS developers: + // effective barrier ID starts from 0 + CUTLASS_DEVICE + static void arrive_and_wait(uint32_t num_threads, ReservedNamedBarriers reserved_named_barriers) { + arrive_and_wait_internal(num_threads, static_cast(reserved_named_barriers)); + } + + // Calling interface for CUTLASS users: + // effective barrier ID starts from ReservedNamedBarrierCount + CUTLASS_DEVICE + static void arrive(uint32_t num_threads, uint32_t barrier_id) { + arrive_internal(num_threads, barrier_id + ReservedNamedBarrierCount); + } + + // Calling interface for CUTLASS developers: + // effective barrier ID starts from 0 + CUTLASS_DEVICE + static void arrive(uint32_t num_threads, ReservedNamedBarriers reserved_named_barriers) { + arrive_internal(num_threads, static_cast(reserved_named_barriers)); + } + + // Calling interface for CUTLASS users: + // effective barrier ID starts from ReservedNamedBarrierCount + CUTLASS_DEVICE + static void sync(uint32_t num_threads, uint32_t barrier_id) { + sync_internal(num_threads, barrier_id + ReservedNamedBarrierCount); + } + + // Calling interface for CUTLASS developers: + // effective barrier ID starts from 0 + CUTLASS_DEVICE + static void sync(uint32_t num_threads, ReservedNamedBarriers reserved_named_barriers) { + sync_internal(num_threads, static_cast(reserved_named_barriers)); + } + + + private: + CUTLASS_DEVICE + static void arrive_and_wait_internal(uint32_t num_threads, uint32_t barrier_id) { +#if CUDA_BARRIER_ENABLED + asm volatile("bar.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); + cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + } + + CUTLASS_DEVICE + static void arrive_and_wait_internal_unaligned(uint32_t num_threads, uint32_t barrier_id) { +#if CUDA_BARRIER_ENABLED + asm volatile("barrier.sync %0, %1;" : : "r"(barrier_id), "r"(num_threads)); + cutlass::arch::synclog_emit_named_barrier_arrive_and_wait(__LINE__, num_threads, barrier_id); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + } + + CUTLASS_DEVICE + static void arrive_internal(uint32_t num_threads, uint32_t barrier_id) { +#if CUDA_BARRIER_ENABLED + cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id); + asm volatile("bar.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads)); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + } + + CUTLASS_DEVICE + static void arrive_internal_unaligned(uint32_t num_threads, uint32_t barrier_id) { +#if CUDA_BARRIER_ENABLED + cutlass::arch::synclog_emit_named_barrier_arrive(__LINE__, num_threads, barrier_id); + asm volatile("barrier.arrive %0, %1;" : : "r"(barrier_id), "r"(num_threads)); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + } + + CUTLASS_DEVICE + static void sync_internal(uint32_t num_threads, uint32_t barrier_id) { + NamedBarrier::arrive_and_wait_internal(num_threads, barrier_id); + } + + public: + // Currently we reserve 8 NamedBarriers for CUTLASS' own use 
cases, + // while leaving the renaming for general users. + static const uint32_t ReservedNamedBarrierCount = static_cast(ReservedNamedBarriers::FirstUserBarrier); + static const uint32_t HardwareMaxNumNamedBarriers = 16; + +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Hopper introduces a new cluster-wide barrier which handle with Cluster-wide arrive-wait behaviour. +// This is an extension to the Ampere arrive-wait barriers +// Note : Ampere arrive-wait Barriers have a larger max-arrive count (2^30) than Hopper arrive-wait Barriers (2^20). +struct ClusterBarrier { + + using ValueType = uint64_t; + +protected: + // Can never be initialized - can only be aliased to smem + ValueType barrier_; + +public: + + CUTLASS_DEVICE + ClusterBarrier() = delete; + + CUTLASS_DEVICE + void init(uint32_t arrive_count) const { + ClusterBarrier::init(&this->barrier_, arrive_count); + } + + CUTLASS_DEVICE + bool test_wait(uint32_t phase, uint32_t pred=true) const { + return ClusterBarrier::test_wait(&this->barrier_, phase, pred); + } + + CUTLASS_DEVICE + bool try_wait(uint32_t phase) const { + return ClusterBarrier::try_wait(&this->barrier_, phase); + } + + CUTLASS_DEVICE + void wait(uint32_t phase) const { + ClusterBarrier::wait(&this->barrier_, phase); + } + + // Barrier arrive on local smem + CUTLASS_DEVICE + void arrive() const { + ClusterBarrier::arrive(&this->barrier_); + } + + // Remote SMEM arrive with a perdicate (usually done to pick the thread doing the arrive) + CUTLASS_DEVICE + void arrive(uint32_t cta_id, uint32_t pred = true ) const { + ClusterBarrier::arrive(&this->barrier_, cta_id, pred); + } + + // + // Static Versions + // + CUTLASS_DEVICE + static void init(ValueType const* smem_ptr, uint32_t arrive_count) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + "mbarrier.init.shared::cta.b64 [%1], %0; \n" + "}" + : + : "r"(arrive_count), "r"(smem_addr)); + cutlass::arch::synclog_emit_cluster_barrier_init(__LINE__, smem_addr, arrive_count); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + } + + // Static version of wait - in case we don't want to burn a register + CUTLASS_DEVICE + static void wait(ValueType const* smem_ptr, uint32_t phase) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_cluster_barrier_wait(__LINE__, smem_addr, phase); + // Arbitrarily large timer value after which try-wait expires and re-tries. 
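As a hedged usage sketch of the ClusterBarrier object API above (CTA scope only, Hopper/SM90 with CUDA 12+; the arrive count and thread roles are illustrative): one producer thread arrives on a shared-memory barrier and every thread waits on the parity of its first completion.

#include <cutlass/arch/barrier.h>

__global__ void cluster_barrier_demo()
{
  __shared__ uint64_t bar_storage;
  auto& bar = *reinterpret_cast<cutlass::arch::ClusterBarrier*>(&bar_storage);

  if (threadIdx.x == 0) {
    bar.init(/*arrive_count=*/ 1);
  }
  __syncthreads();                 // barrier is initialized before anyone arrives or waits

  if (threadIdx.x == 0) {
    // ... produce something in shared memory ...
    bar.arrive();                  // completes phase 0, since arrive_count == 1
  }

  uint32_t phase = 0;
  bar.wait(phase);                 // all threads block until phase 0 has completed
  phase ^= 1;                      // subsequent waits on this barrier use the flipped parity
}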
+ uint32_t ticks = 0x989680; + asm volatile( + "{\n\t" + ".reg .pred P1; \n\t" + "LAB_WAIT: \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [%0], %1, %2; \n\t" + "@P1 bra DONE; \n\t" + "bra LAB_WAIT; \n\t" + "DONE: \n\t" + "}" + : + : "r"(smem_addr), "r"(phase), "r"(ticks)); + +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + } + + CUTLASS_DEVICE + static bool test_wait(ValueType const* smem_ptr, uint32_t phase, uint32_t pred) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_cluster_barrier_test_wait(__LINE__, smem_addr, phase, pred); + uint32_t waitComplete; + + asm volatile( + "{\n\t" + ".reg .pred P1; \n\t" + ".reg .pred P2; \n\t" + "setp.eq.u32 P2, %3, 1;\n\t" + "@P2 mbarrier.test_wait.parity.shared::cta.b64 P1, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P1; \n\t" + "}" + : "=r"(waitComplete) + : "r"(smem_addr), "r"(phase), "r"(pred)); + + return static_cast(waitComplete); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + return 0; + } + + CUTLASS_DEVICE + static bool try_wait(ValueType const* smem_ptr, uint32_t phase) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + cutlass::arch::synclog_emit_cluster_barrier_try_wait(__LINE__, smem_addr, phase); + uint32_t waitComplete; + + asm volatile( + "{\n\t" + ".reg .pred P1; \n\t" + "mbarrier.try_wait.parity.shared::cta.b64 P1, [%1], %2; \n\t" + "selp.b32 %0, 1, 0, P1; \n\t" + "}" + : "=r"(waitComplete) + : "r"(smem_addr), "r"(phase)); + + return static_cast(waitComplete); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + return 0; + } + + // Static Predicated version of the above - in case we know the address. + CUTLASS_DEVICE + static void arrive(ValueType const* smem_ptr, uint32_t cta_id, uint32_t pred) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + if (pred) { + asm volatile( + "{\n\t" + ".reg .b32 remAddr32;\n\t" + "mapa.shared::cluster.u32 remAddr32, %0, %1;\n\t" + "mbarrier.arrive.shared::cluster.b64 _, [remAddr32];\n\t" + "}" + : + : "r"(smem_addr), "r"(cta_id)); + } + + cutlass::arch::synclog_emit_cluster_barrier_arrive_cluster(__LINE__, smem_addr, cta_id, pred); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + } + + // Barrier arrive on local smem + CUTLASS_DEVICE + static void arrive(ValueType const* smem_ptr) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + "mbarrier.arrive.shared::cta.b64 _, [%0];\n\t" + "}" + : + : "r"(smem_addr)); + cutlass::arch::synclog_emit_cluster_barrier_arrive(__LINE__, smem_addr); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + } + + CUTLASS_DEVICE + static void invalidate(ValueType const* smem_ptr) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + "mbarrier.inval.shared::cta.b64 [%0]; \n\t" + "}" + : + : "r"(smem_addr)); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// SM90 also introduces a new type of cluster-barrier which supports sync. 
+// not just based on Arrive Count, but also transaction count (in bytes) +struct ClusterTransactionBarrier : public ClusterBarrier { + + CUTLASS_DEVICE + ClusterTransactionBarrier() = delete; + + // Performs an arrive operation + expected transaction bytes increment + CUTLASS_DEVICE + void arrive_and_expect_tx(uint32_t transaction_bytes) const { + ClusterTransactionBarrier::arrive_and_expect_tx(&this->barrier_, transaction_bytes); + } + + // Performs an arrive operation + expected transaction bytes increment + CUTLASS_DEVICE + void arrive_and_expect_tx(uint32_t transaction_bytes, uint32_t cta_id, uint32_t pred = 1u) const { + ClusterTransactionBarrier::arrive_and_expect_tx(&this->barrier_, transaction_bytes , cta_id, pred); + } + + // Performs an expected transaction bytes increment without doing an arrive operation + CUTLASS_DEVICE + void expect_transaction(uint32_t transaction_bytes) const { + ClusterTransactionBarrier::expect_transaction(&this->barrier_, transaction_bytes); + } + + // Performs an expected transaction bytes decrement without doing an arrive operation + CUTLASS_DEVICE + void complete_transaction(uint32_t transaction_bytes, uint32_t pred = 1) const { + uint32_t cta_rank = cute::block_rank_in_cluster(); + ClusterTransactionBarrier::complete_transaction(&this->barrier_, cta_rank, transaction_bytes, pred); + } + + // Performs an expected transaction bytes decrement without doing an arrive operation + CUTLASS_DEVICE + void complete_transaction(uint32_t dst_cta_id, uint32_t transaction_bytes, uint32_t pred) const { + ClusterTransactionBarrier::complete_transaction(&this->barrier_, dst_cta_id, transaction_bytes, pred); + } + + // + // Static Versions + // + + // Performs an arrive operation + expected transaction bytes increment + CUTLASS_DEVICE + static void arrive_and_expect_tx(ValueType const* smem_ptr, uint32_t transaction_bytes) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + "mbarrier.arrive.expect_tx.shared::cta.b64 _, [%1], %0; \n\t" + "}" + : + : "r"(transaction_bytes), "r"(smem_addr)); + cutlass::arch::synclog_emit_cluster_transaction_barrier_arrive_and_expect_tx(__LINE__, smem_addr, transaction_bytes); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + } + + // Performs an arrive operation + expected transaction bytes increment for a remote cta_id in a Cluster + CUTLASS_DEVICE + static void arrive_and_expect_tx( + ValueType const* smem_ptr, uint32_t transaction_bytes, uint32_t cta_id, uint32_t pred) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + ".reg .b32 remAddr32;\n\t" + "setp.eq.u32 p, %2, 1;\n\t" + "@p mapa.shared::cluster.u32 remAddr32, %0, %1;\n\t" + "@p mbarrier.arrive.expect_tx.shared::cluster.b64 _, [remAddr32], %3;\n\t" + "}" + : + : "r"(smem_addr), "r"(cta_id), "r"(pred), "r"(transaction_bytes)); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + } + + // Performs an expected transaction bytes increment without doing an arrive operation + CUTLASS_DEVICE + static void expect_transaction(ValueType const* smem_ptr, uint32_t transaction_bytes) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + "mbarrier.expect_tx.shared::cta.b64 [%1], %0; \n\t" + "}" + : + : "r"(transaction_bytes), "r"(smem_addr)); + cutlass::arch::synclog_emit_cluster_transaction_barrier_expect_transaction(__LINE__, 
smem_addr, transaction_bytes); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + } + + // Performs an expected transaction bytes decrement without doing an arrive operation + CUTLASS_DEVICE + static void complete_transaction( + ValueType const* smem_ptr, uint32_t dst_cta_id, uint32_t transaction_bytes, uint32_t pred = 1) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + smem_addr = cute::set_block_rank(smem_addr, dst_cta_id); + asm volatile( + "{\n\t" + ".reg .pred p;\n\t" + "setp.eq.u32 p, %2, 1;\n\t" + "@p mbarrier.complete_tx.shared::cluster.relaxed.cluster.b64 [%1], %0;" + "}" + : + : "r"(transaction_bytes), "r"(smem_addr), "r"(pred)); + cutlass::arch::synclog_emit_cluster_transaction_barrier_complete_transaction(__LINE__, smem_addr, dst_cta_id, transaction_bytes, pred); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif + } + + // + // DEPRECATED APIs + // + [[deprecated("Use arrive_and_expect_tx instead")]] CUTLASS_DEVICE + void arrive_and_reset_bytes(uint32_t transaction_bytes) const { + arrive_and_expect_tx(transaction_bytes); + } + [[deprecated("Use arrive_and_expect_tx instead")]] CUTLASS_DEVICE + void arrive_and_reset_bytes(uint32_t transaction_bytes, uint32_t cta_id) const { + arrive_and_expect_tx(transaction_bytes, cta_id); + } + [[deprecated("Use expect_transaction instead")]] CUTLASS_DEVICE + void reset_bytes(uint32_t transaction_bytes) const { + expect_transaction(transaction_bytes); + } + [[deprecated("Use complete_transaction instead")]] CUTLASS_DEVICE + void commit(uint32_t transaction_bytes, uint32_t pred = 1) const { + complete_transaction(transaction_bytes, pred); + } + [[deprecated("Use complete_transaction instead")]] CUTLASS_DEVICE + void commit(uint32_t dst_cta_id, uint32_t transaction_bytes, uint32_t pred) const { + complete_transaction(dst_cta_id, transaction_bytes, pred); + } + [[deprecated("Use arrive_and_expect_tx instead")]] CUTLASS_DEVICE + static void arrive_and_reset_bytes(ValueType const* smem_ptr, uint32_t transaction_bytes) { + arrive_and_expect_tx(smem_ptr, transaction_bytes); + } + [[deprecated("Use arrive_and_expect_tx instead")]] CUTLASS_DEVICE + static void arrive_and_reset_bytes(ValueType const* smem_ptr, uint32_t transaction_bytes, uint32_t cta_id, uint32_t pred) { + arrive_and_expect_tx(smem_ptr, transaction_bytes, cta_id, pred); + } + [[deprecated("Use expect_transaction instead")]] CUTLASS_DEVICE + static void reset_bytes(ValueType const* smem_ptr, uint32_t transaction_bytes) { + expect_transaction(smem_ptr, transaction_bytes); + } + [[deprecated("Use complete_transaction instead")]] CUTLASS_DEVICE + static void commit(ValueType const* smem_ptr, uint32_t dst_cta_id, uint32_t transaction_bytes, uint32_t pred = 1) { + complete_transaction(smem_ptr, dst_cta_id, transaction_bytes, pred); + } +}; + +// Helps with visibility of barrier init operations across warps / cta / cluster +// Available as a separate function so as to batch inits across barriers and fence once +// Note : It must be composed with an appropriate sync instruction with the right scope +// to ensure visibility eg. 
__syncthreads() or a cluster_arrive() + cluster_wait() +CUTLASS_DEVICE +void fence_barrier_init() { +#if CUDA_BARRIER_ENABLED + cutlass::arch::synclog_emit_fence_barrier_init(__LINE__); + asm volatile( + "{\n\t" + "fence.mbarrier_init.release.cluster; \n" + "}" + ::); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif +} + +// Issue a shared memory fence for async operations +CUTLASS_DEVICE +void fence_view_async_shared() { +#if CUDA_BARRIER_ENABLED + cutlass::arch::synclog_emit_fence_view_async_shared(__LINE__); + asm volatile ( + "{\n\t" + "fence.proxy.async.shared::cta; \n" + "}" + ::); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif +} + +// Arrive on completion of in-flight cp.async operations issued by the calling thread +CUTLASS_DEVICE +void cpasync_barrier_arrive(uint64_t const* smem_ptr) { +#if CUDA_BARRIER_ENABLED + uint32_t smem_addr = cute::cast_smem_ptr_to_uint(smem_ptr); + asm volatile( + "{\n\t" + "cp.async.mbarrier.arrive.shared::cta.b64 [%0];\n\t" + "}" + : + : "r"(smem_addr)); + cutlass::arch::synclog_emit_cpasync_barrier_arrive(__LINE__, smem_addr); +#elif defined(__CUDA_ARCH__) + asm volatile ("brkpt;\n" ::); +#endif +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////////////////////////// +} // end namespace arch +} // end namespace cutlass diff --git a/include/cutlass/arch/cache_operation.h b/include/cutlass/arch/cache_operation.h index d84d4790d0..9d2344bf32 100644 --- a/include/cutlass/arch/cache_operation.h +++ b/include/cutlass/arch/cache_operation.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/include/cutlass/arch/config.h b/include/cutlass/arch/config.h new file mode 100644 index 0000000000..0fc60f41db --- /dev/null +++ b/include/cutlass/arch/config.h @@ -0,0 +1,85 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Definitions for architecture macros +*/ + +#pragma once + +#include "cutlass/platform/platform.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// SM90 +#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 0)) + #define CUTLASS_ARCH_MMA_SM90_SUPPORTED 1 + #if (!defined(CUTLASS_ARCH_MMA_SM90_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900) + #define CUTLASS_ARCH_MMA_SM90_ENABLED 1 + + #if (!defined(CUTLASS_ARCH_MMA_SM90A_ENABLED) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) + #define CUTLASS_ARCH_MMA_SM90A_ENABLED 1 + #endif + #endif +#endif + +#if (__CUDACC_VER_MAJOR__ >= 12 && __CUDACC_VER_MINOR__ >= 2) + #define CUTLASS_ARCH_MMA_SPARSE_SM90_SUPPORTED +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// SM90 Modifiable +#if (__CUDACC_VER_MAJOR__ > 12 || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 3)) + #define CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_SUPPORTED 1 + #if (!defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ == 900) + #define CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90_ENABLED 1 + + #if (!defined(CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90A_ENABLED) && defined(__CUDA_ARCH_FEAT_SM90_ALL)) + #define CUTLASS_ARCH_MMA_MODIFIABLE_TMA_SM90A_ENABLED 1 + #endif + #endif +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// SM90 F64 +#if (__CUDACC_VER_MAJOR__ > 11 || (__CUDACC_VER_MAJOR__ == 11 && __CUDACC_VER_MINOR__ >= 8)) + #define CUTLASS_ARCH_MMA_SM90_F64_MMA_SUPPORTED 1 + #if (!defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) && defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900) + #define CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED 1 + #endif +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cutlass/arch/grid_dependency_control.h b/include/cutlass/arch/grid_dependency_control.h new file mode 100644 index 0000000000..14ef197497 --- /dev/null +++ b/include/cutlass/arch/grid_dependency_control.h @@ -0,0 +1,84 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Grid dependent control (GDC) helpers for programmatic dependent launches (PDL). +*/ + +#pragma once + +#include "cute/arch/cluster_sm90.hpp" +#include "cutlass/arch/barrier.h" +#include "cutlass/conv/dispatch_policy.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" + +#ifndef CUTLASS_GDC_ENABLED + #if (defined(CUTLASS_ENABLE_GDC_FOR_SM90) && \ + __CUDACC_VER_MAJOR__ >= 12 && \ + defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL)) + #define CUTLASS_GDC_ENABLED + #endif +#endif + +namespace cutlass { +namespace arch { + +// Issuing the launch_dependents instruction hints a dependent kernel to launch earlier +// launch_dependents doesn't impact the functionality but the performance: +// Launching a dependent kernel too early can compete with current kernels, +// while launching too late can lead to a long latency. +CUTLASS_DEVICE +void launch_dependent_grids() { +#if (defined(CUTLASS_GDC_ENABLED)) + asm volatile("griddepcontrol.launch_dependents;"); +#endif +} + +// Issuing the griddepcontrol.wait instruction enforces no global memory access +// prior to this istruction. This ensures the correctness of global memory access +// when launching a dependent kernel earlier. +CUTLASS_DEVICE +void wait_on_dependent_grids() { +#if (defined(CUTLASS_GDC_ENABLED)) + asm volatile("griddepcontrol.wait;"); +#endif +} + +// Enable kernel-level query regarding whether the GDC feature is turned on +#if (defined(CUTLASS_GDC_ENABLED)) +static constexpr bool IsGdcGloballyEnabled = true; +#else +static constexpr bool IsGdcGloballyEnabled = false; +#endif + + +} // namespace arch +} // namespace cutlass diff --git a/include/cutlass/arch/memory.h b/include/cutlass/arch/memory.h index a41110cbc1..db9ad7397c 100644 --- a/include/cutlass/arch/memory.h +++ b/include/cutlass/arch/memory.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -35,6 +35,8 @@ #pragma once #include "cutlass/cutlass.h" +#include "cutlass/arch/cache_operation.h" +#include "cutlass/platform/platform.h" namespace cutlass { namespace arch { @@ -45,7 +47,9 @@ template < /// Fragment type to store loaded data typename AccessType, /// The bytes of loading - int LoadBytes + int LoadBytes, + /// Cache operation + CacheOperation::Kind cache_op = CacheOperation::Always > struct global_load; @@ -59,8 +63,7 @@ struct global_load; #if (((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 4)) || \ (__CUDACC_VER_MAJOR__ > 11)) && \ - defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) && \ - ! (defined(__clang__) && defined(__CUDA__)) + defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) #define CUTLASS_ENABLE_L2_PREFETCH 1 #else #define CUTLASS_ENABLE_L2_PREFETCH 0 @@ -69,10 +72,11 @@ struct global_load; ///////////////////////////////////////////////////////////////////////////////////////////////// // The redundant mov PTX instruction is used to enforce the compiler to -// initialize data to zero before ld.global +// keep the initializing code before ld.global template struct global_load { CUTLASS_DEVICE global_load(AccessType &D, void const *ptr, bool pred_guard) { @@ -108,7 +112,40 @@ struct global_load struct global_load { + CUTLASS_DEVICE + global_load(AccessType &D, void const *ptr, bool pred_guard) { + uint4 *data = reinterpret_cast(&D); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %9, 0;\n" + " mov.b32 %0, %10;\n" + " mov.b32 %1, %11;\n" + " mov.b32 %2, %12;\n" + " mov.b32 %3, %13;\n" + " mov.b32 %4, %14;\n" + " mov.b32 %5, %15;\n" + " mov.b32 %6, %16;\n" + " mov.b32 %7, %17;\n" + " @p ld.global.lu.v4.u32 {%0, %1, %2, %3}, [%8];\n" + " @p ld.global.lu.v4.u32 {%4, %5, %6, %7}, [%18];\n" + "}\n" + : "=r"(data[0].x), "=r"(data[0].y), "=r"(data[0].z), "=r"(data[0].w), + "=r"(data[1].x), "=r"(data[1].y), "=r"(data[1].z), "=r"(data[1].w) + : "l"(ptr), "r"((int)pred_guard), "r"(data[0].x), "r"(data[0].y), + "r"(data[0].z), "r"(data[0].w), "r"(data[1].x), "r"(data[1].y), + "r"(data[1].z), "r"(data[1].w), "l"(((uint8_t *)ptr) + 16)); + } +}; + +template +struct global_load { CUTLASS_DEVICE global_load(AccessType &D, void const *ptr, bool pred_guard) { @@ -134,7 +171,31 @@ struct global_load struct global_load { + CUTLASS_DEVICE + global_load(AccessType &D, void const *ptr, bool pred_guard) { + uint4 &data = reinterpret_cast(D); + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %5, 0;\n" + " mov.b32 %0, %6;\n" + " mov.b32 %1, %7;\n" + " mov.b32 %2, %8;\n" + " mov.b32 %3, %9;\n" + " @p ld.global.lu.v4.u32 {%0, %1, %2, %3}, [%4];\n" + "}\n" + : "=r"(data.x), "=r"(data.y), "=r"(data.z), "=r"(data.w) + : "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y), "r"(data.z), "r"(data.w)); + } +}; + +template +struct global_load { CUTLASS_DEVICE global_load(AccessType &D, void const *ptr, bool pred_guard) { @@ -159,7 +220,30 @@ struct global_load struct global_load { + CUTLASS_DEVICE + global_load(AccessType &D, void const *ptr, bool pred_guard) { + uint2 &data = reinterpret_cast(D); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %3, 0;\n" + " mov.b32 %0, %4;\n" + " mov.b32 %1, %5;\n" + " @p ld.global.lu.v2.u32 {%0, %1}, [%2];\n" + "}\n" + : "=r"(data.x), "=r"(data.y) + : "l"(ptr), "r"((int)pred_guard), "r"(data.x), "r"(data.y)); + } +}; + +template +struct global_load { CUTLASS_DEVICE 
global_load(AccessType &D, void const *ptr, bool pred_guard) { @@ -183,7 +267,29 @@ struct global_load struct global_load { + CUTLASS_DEVICE + global_load(AccessType &D, void const *ptr, bool pred_guard) { + unsigned &data = reinterpret_cast(D); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %2, 0;\n" + " mov.b32 %0, %3;\n" + " @p ld.global.lu.u32 %0, [%1];\n" + "}\n" + : "=r"(data) + : "l"(ptr), "r"((int)pred_guard), "r"(data)); + } +}; + +template +struct global_load { CUTLASS_DEVICE global_load(AccessType &D, void const *ptr, bool pred_guard) { @@ -207,7 +313,29 @@ struct global_load struct global_load { + CUTLASS_DEVICE + global_load(AccessType &D, void const *ptr, bool pred_guard) { + uint16_t &data = reinterpret_cast(D); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %2, 0;\n" + " mov.b16 %0, %3;\n" + " @p ld.global.lu.u16 %0, [%1];\n" + "}\n" + : "=h"(data) + : "l"(ptr), "r"((int)pred_guard), "h"(data)); + } +}; + +template +struct global_load { CUTLASS_DEVICE global_load(AccessType &D, void const *ptr, bool pred_guard) { @@ -451,7 +579,7 @@ template <> CUTLASS_DEVICE void shared_store<16>(uint32_t ptr, void const *src) { uint4 const *dst_u128 = reinterpret_cast(src); - asm volatile("ld.shared.v4.u32 [%0], {%1, %2, %3, %4};\n" + asm volatile("st.shared.v4.u32 [%0], {%1, %2, %3, %4};\n" : : "r"(ptr), "r"(dst_u128->x), @@ -468,7 +596,7 @@ void shared_store<16>(uint32_t ptr, void const *src) { ///////////////////////////////////////////////////////////////////////////////////////////////// -#include "memory_sm75.h" -#include "memory_sm80.h" +#include "cutlass/arch/memory_sm75.h" +#include "cutlass/arch/memory_sm80.h" ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/arch/memory_sm75.h b/include/cutlass/arch/memory_sm75.h index d97682ed30..0e957c72ae 100644 --- a/include/cutlass/arch/memory_sm75.h +++ b/include/cutlass/arch/memory_sm75.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -35,7 +35,10 @@ #pragma once #include "cutlass/array.h" +#include "cutlass/detail/helper_macros.hpp" #include "cutlass/layout/matrix.h" +#include "cute/arch/copy_sm75.hpp" +#include "cute/arch/util.hpp" namespace cutlass { namespace arch { @@ -48,7 +51,7 @@ template < /// .x1, .x2, or .x4 int MatrixCount > -inline __device__ void ldsm(Array & D, void const* ptr); +CUTLASS_DEVICE void ldsm(Array & D, void const* ptr); ///////////////////////////////////////////////////////////////////////////////////////////////// // @@ -56,96 +59,24 @@ inline __device__ void ldsm(Array & D, void const* ptr); // ///////////////////////////////////////////////////////////////////////////////////////////////// -#if (__CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) || (__CUDACC_VER_MAJOR__ >= 11) - -#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 750) -#define CUDA_LDMATRIX_ACTIVATED 1 -#endif - -#define CUDA_LDMATRIX_SUPPORTED 1 -#endif - -///////////////////////////////////////////////////////////////////////////////////////////////// -/* -#if ! 
defined(CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED) && (__CUDACC_VER_MAJOR__ > 10) - #define CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED 1 -#endif -#if ! defined(CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED) - #define CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED ((__CUDACC_VER_MAJOR__ == 10) && (__CUDACC_VER_MINOR__ >= 1)) -#endif - -#if ! defined(CUDA_NVVM_GET_SMEM_POINTER_ENABLED) - #define CUDA_NVVM_GET_SMEM_POINTER_ENABLED CUDA_NVVM_GET_SMEM_POINTER_SUPPORTED -#endif -*/ - -#if (! defined (__clang__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) - extern "C" { - // - // This NVVM intrinsic is subject to change in future versions of CUDA. - // Clients should not call it directly. Rather, they should use the - // cutlass::arch::ldsm<>() template. - // - __device__ uint32_t __nvvm_get_smem_pointer(void *); - } -#endif - -///////////////////////////////////////////////////////////////////////////////////////////////// - /// CUTLASS helper to get SMEM pointer -inline __device__ unsigned cutlass_get_smem_pointer(void *ptr) { - -// We prefer to use the new CVTA intrinsics if they are available, otherwise we will fall back to -// the previous internal intrinsics if they are available. -#if (! defined (__clang__) && defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ >= 11) - // - // This NVVM intrinsic converts an address in shared memory to a plain - // unsigned integer. This is necessary to pass to shared memory instructions - // in inline PTX. - // - // In CUDA 11 and beyond, this replaces __nvvm_get_smem_pointer() [only available in 10.2]. - // - //__device__ size_t __cvta_generic_to_shared(void* ptr); - - /// CUTLASS helper to get SMEM pointer - return static_cast(__cvta_generic_to_shared(ptr)); - -#elif (! defined (__clang__) && defined(__CUDA_ARCH__) && __CUDACC_VER_MAJOR__ == 10 && __CUDACC_VER_MINOR__ >= 2) - - return __nvvm_get_smem_pointer(ptr); - -#elif defined(__CUDA_ARCH__) - - uint32_t smem_ptr; - - asm( - "{ .reg .u64 smem_ptr; cvta.to.shared.u64 smem_ptr, %1; cvt.u32.u64 %0, smem_ptr; }\n" - : "=r"(smem_ptr) : "l"(ptr)); - - return smem_ptr; - -#else - - CUTLASS_UNUSED(ptr); - CUTLASS_NOT_IMPLEMENTED(); - return 0; - -#endif +CUTLASS_DEVICE unsigned cutlass_get_smem_pointer(void *ptr) { + return cute::cast_smem_ptr_to_uint(ptr); } - + /// CUTLASS helper to get SMEM pointer -inline __device__ unsigned cutlass_get_smem_pointer(void const *ptr) { +CUTLASS_DEVICE unsigned cutlass_get_smem_pointer(void const *ptr) { return cutlass_get_smem_pointer(const_cast(ptr)); } ///////////////////////////////////////////////////////////////////////////////////////////////// template <> -inline __device__ void ldsm( +CUTLASS_DEVICE void ldsm( Array & D, void const* ptr) { - #if defined(CUDA_LDMATRIX_ACTIVATED) + #if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) unsigned addr = cutlass_get_smem_pointer(ptr); @@ -165,11 +96,11 @@ inline __device__ void ldsm( ///////////////////////////////////////////////////////////////////////////////////////////////// template <> -inline __device__ void ldsm( +CUTLASS_DEVICE void ldsm( Array & D, void const* ptr) { - #if defined(CUDA_LDMATRIX_ACTIVATED) + #if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) unsigned addr = cutlass_get_smem_pointer(ptr); @@ -189,11 +120,11 @@ inline __device__ void ldsm( ///////////////////////////////////////////////////////////////////////////////////////////////// template <> -inline __device__ void ldsm( +CUTLASS_DEVICE void ldsm( Array & D, void const* ptr) { - #if defined(CUDA_LDMATRIX_ACTIVATED) + #if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) unsigned addr = 
cutlass_get_smem_pointer(ptr); @@ -217,11 +148,11 @@ inline __device__ void ldsm( ///////////////////////////////////////////////////////////////////////////////////////////////// template <> -inline __device__ void ldsm( +CUTLASS_DEVICE void ldsm( Array & D, void const* ptr) { - #if CUDA_LDMATRIX_ACTIVATED + #if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) unsigned addr = cutlass_get_smem_pointer(ptr); @@ -241,11 +172,11 @@ inline __device__ void ldsm( ///////////////////////////////////////////////////////////////////////////////////////////////// template <> -inline __device__ void ldsm( +CUTLASS_DEVICE void ldsm( Array & D, void const* ptr) { - #if defined(CUDA_LDMATRIX_ACTIVATED) + #if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) unsigned addr = cutlass_get_smem_pointer(ptr); @@ -265,11 +196,11 @@ inline __device__ void ldsm( ///////////////////////////////////////////////////////////////////////////////////////////////// template <> -inline __device__ void ldsm( +CUTLASS_DEVICE void ldsm( Array & D, void const* ptr) { - #if defined(CUDA_LDMATRIX_ACTIVATED) + #if defined(CUTE_ARCH_LDSM_SM75_ACTIVATED) unsigned addr = cutlass_get_smem_pointer(ptr); diff --git a/include/cutlass/arch/memory_sm80.h b/include/cutlass/arch/memory_sm80.h index 48a499c5c5..cb0ba4b54b 100644 --- a/include/cutlass/arch/memory_sm80.h +++ b/include/cutlass/arch/memory_sm80.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -36,6 +36,7 @@ #pragma once #include "cutlass/cutlass.h" +#include "cutlass/complex.h" #include "cutlass/arch/memory.h" #include "cutlass/arch/memory_sm75.h" #include "cutlass/arch/cache_operation.h" @@ -53,7 +54,7 @@ namespace arch { /// Initiates an asynchronous copy from global memory to shared memory. /// -/// LDGSTS +/// cp.async /// template < /// Size of the access in bytes @@ -65,7 +66,7 @@ struct cp_async; /// Initiates an asynchronous copy from global memory to shared memory. Rather than predicate /// the entire transfer, zeros are written to SMEM if the guard predicate is false. /// -/// LDGSTS +/// cp.async /// template < /// Size of the access in bytes @@ -77,7 +78,7 @@ struct cp_async_zfill; /// Initiates an asynchronous copy from global memory to shared memory. Rather than predicate /// the entire transfer, nans (0x7eff) are written to SMEM if the guard predicate is false. 
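/// (Reinterpreted as an FP16 value, 0x7eff has an all-ones exponent and a non-zero mantissa with
/// its top bit set, i.e. a quiet NaN, so out-of-bounds elements read back as NaN rather than zero.)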
/// -/// LDGSTS +/// cp.async /// template < /// Size of the access in bytes @@ -89,7 +90,7 @@ struct cp_async_nan; /// Either 0 or 1 are written to SMEM based on input element type /// Used for diagonal elements of triangular matrix of BLAS3 functions /// -/// STS +/// st.shared /// template < /// Type of Element @@ -98,6 +99,9 @@ template < bool IsHermitianData = false> struct cp_async_diag; +static const uint32_t OOB_NAN_F16 = 0x7eff; +static const uint32_t OOB_NAN_F16x2 = ((OOB_NAN_F16 << 16) | OOB_NAN_F16); + //////////////////////////////////////////////////////////////////////////////////////////////////// /// Partial specialization @@ -190,8 +194,8 @@ struct cp_async_nan<16, CacheOperation::Always> { cp_async_nan(void *smem_ptr, void const *global_ptr, bool pred_guard) { #if CUDA_CP_ASYNC_ACTIVATED - static __constant__ uint4 OOB_NAN_F16x8 = {0x7eff7eff, 0x7eff7eff, - 0x7eff7eff, 0x7eff7eff}; + static __constant__ uint4 OOB_NAN_F16x8 = {OOB_NAN_F16x2, OOB_NAN_F16x2, + OOB_NAN_F16x2, OOB_NAN_F16x2}; unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); @@ -305,7 +309,6 @@ struct cp_async_diag { } }; - //////////////////////////////////////////////////////////////////////////////////////////////////// /// Partial specialization @@ -323,6 +326,7 @@ struct cp_async { "cp.async only supports CacheOperation::Global when access size is 16B."); unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); + cutlass::arch::synclog_emit_cp_async(__LINE__, smem_int_ptr, global_ptr, pred_guard, SizeInBytes); asm volatile( "{\n" @@ -362,6 +366,7 @@ struct cp_async_zfill { unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); int src_in_bytes = (pred_guard ? SizeInBytes : 0); + cutlass::arch::synclog_emit_cp_async_zfill(__LINE__, smem_int_ptr, global_ptr, pred_guard, SizeInBytes); asm volatile( #if CUTLASS_ENABLE_L2_PREFETCH @@ -386,6 +391,48 @@ struct cp_async_zfill { } }; +/// Partial specialization +template <> +struct cp_async_nan<16, CacheOperation::Global> { + static int const kSizeInBytes = 16; + + /// Copy with nan fill + CUTLASS_DEVICE + cp_async_nan(void *smem_ptr, void const *global_ptr, bool pred_guard) { + #if CUDA_CP_ASYNC_ACTIVATED + + static __constant__ uint4 OOB_NAN_F16x8 = {OOB_NAN_F16x2, OOB_NAN_F16x2, + OOB_NAN_F16x2, OOB_NAN_F16x2}; + + unsigned smem_int_ptr = cutlass_get_smem_pointer(smem_ptr); + cutlass::arch::synclog_emit_cp_async_nan(__LINE__, smem_int_ptr, global_ptr, pred_guard); + + asm volatile( + "{\n" + " .reg .pred p;\n" + " setp.ne.b32 p, %0, 0;\n" +#if CUTLASS_ENABLE_L2_PREFETCH + " @p cp.async.cg.shared.global.L2::128B [%1], [%2], %3;\n" +#else + " @p cp.async.cg.shared.global [%1], [%2], %3;\n" +#endif + " @!p st.shared.v4.u32 [%1], {%4, %5, %6, %7};\n" + "}\n" + : + : "r"((int)pred_guard), "r"(smem_int_ptr), "l"(global_ptr), + "n"(kSizeInBytes), "r"(OOB_NAN_F16x8.x), "r"(OOB_NAN_F16x8.y), "r"(OOB_NAN_F16x8.z), + "r"(OOB_NAN_F16x8.w)); + + #else + + CUTLASS_UNUSED(smem_ptr); + CUTLASS_UNUSED(global_ptr); + CUTLASS_UNUSED(pred_guard); + CUTLASS_NOT_IMPLEMENTED(); + + #endif + } +}; //////////////////////////////////////////////////////////////////////////////////////////////////// /// Establishes an ordering w.r.t previously issued cp.async instructions. Does not block. 
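//
// Example usage (minimal sketch): the typical pipeline built from the primitives above --
// stage predicated copies, commit them as a group, then wait before consuming the data.
// `smem_ptr`, `gmem_ptr`, and `pred` are placeholder names for this illustration only.
//
//   cutlass::arch::cp_async<16>(smem_ptr, gmem_ptr, pred);  // stage a predicated 16B copy
//   cutlass::arch::cp_async_fence();                        // commit outstanding copies as a group
//   cutlass::arch::cp_async_wait<0>();                      // block until all committed groups land
//   __syncthreads();                                        // make the staged data visible CTA-wide
//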
@@ -393,6 +440,7 @@ CUTLASS_DEVICE void cp_async_fence() { #if CUDA_CP_ASYNC_ACTIVATED asm volatile("cp.async.commit_group;\n" ::); + cutlass::arch::synclog_emit_cp_async_fence(__LINE__); #endif } @@ -403,6 +451,7 @@ template CUTLASS_DEVICE void cp_async_wait() { #if CUDA_CP_ASYNC_ACTIVATED asm volatile("cp.async.wait_group %0;\n" ::"n"(N)); + cutlass::arch::synclog_emit_cp_async_wait(__LINE__, N); #endif } @@ -411,6 +460,7 @@ template <> CUTLASS_DEVICE void cp_async_wait<0>() { #if CUDA_CP_ASYNC_ACTIVATED asm volatile("cp.async.wait_all;\n" ::); + cutlass::arch::synclog_emit_cp_async_wait_all(__LINE__); #endif } diff --git a/include/cutlass/arch/mma.h b/include/cutlass/arch/mma.h index ce3e02f365..007ba19bed 100644 --- a/include/cutlass/arch/mma.h +++ b/include/cutlass/arch/mma.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -49,61 +49,85 @@ namespace arch { ///////////////////////////////////////////////////////////////////////////////////////////////// /// Tag indicating the operation implied by MMA. -struct OpMultiplyAdd; +struct OpMultiplyAdd {}; ///////////////////////////////////////////////////////////////////////////////////////////////// /// Tag indicating the result is saturated to MAX_FLOAT|MIN_FLOAT or MAX_INT|MIN_INT -struct OpMultiplyAddSaturate; +struct OpMultiplyAddSaturate {}; ///////////////////////////////////////////////////////////////////////////////////////////////// /// Tag indicating the input is converted to a narrower type (BF16) -struct OpMultiplyAddFastBF16; +struct OpMultiplyAddFastBF16 {}; ///////////////////////////////////////////////////////////////////////////////////////////////// /// Tag indicating the input is converted to a narrower type (F16) -struct OpMultiplyAddFastF16; +struct OpMultiplyAddFastF16 {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tag indicating the input data types are mixed and the narrower type is +/// upcasted to the wider type +struct OpMultiplyAddMixedInputUpcast {}; ///////////////////////////////////////////////////////////////////////////////////////////////// /// Tag indicating the input is converted to 2 (big and small) TF32 components // Perform 3xTF32 or 4xTF32 for every F32 output element -struct OpMultiplyAddFastF32; +struct OpMultiplyAddFastF32 {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// /// Tag indicating the input is converted to 2 (big and small) TF32 components // Perform 3xTF32 or 4xTF32 for every complex output element -struct OpMultiplyAddComplexFastF32; +struct OpMultiplyAddComplexFastF32 {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tag indicating that staged accumulation is not to be used. This is valid only for SM89 +/// FP8 kernels. 
+struct OpMultiplyAddFastAccum; ///////////////////////////////////////////////////////////////////////////////////////////////// /// Tag indicating the complex multiply-add operation -struct OpMultiplyAddComplex; +struct OpMultiplyAddComplex {}; ///////////////////////////////////////////////////////////////////////////////////////////////// /// Tag indicating the gaussian complex multiply-add operation -struct OpMultiplyAddGaussianComplex; +struct OpMultiplyAddGaussianComplex {}; ///////////////////////////////////////////////////////////////////////////////////////////////// /// Tag indicating the inner product is defined by (XOR, POPC) -struct OpXorPopc; +struct OpXorPopc {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tag indicating the inner product is defined by (AND, POPC) +struct OpAndPopc {}; ///////////////////////////////////////////////////////////////////////////////////////////////// /// Tag classifying math operators as thread-level operations. -struct OpClassSimt; +struct OpClassSimt {}; ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Tag classifing operators as Tensor Core operations. -struct OpClassTensorOp; +/// Tag classifying operators as Tensor Core operations. +struct OpClassTensorOp {}; ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Tag classifing operators as WMMA Tensor Core operations -struct OpClassWmmaTensorOp; +/// Tag classifying operators as WMMA Tensor Core operations +struct OpClassWmmaTensorOp {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Tag classifying operators as Tensor Core with structure sparse operations. +struct OpClassSparseTensorOp {}; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -223,4 +247,23 @@ struct SparseMma; #include "cutlass/arch/mma_sm75.h" #include "cutlass/arch/mma_sm80.h" #include "cutlass/arch/mma_sparse_sm80.h" +#include "cutlass/arch/mma_sm89.h" +#include "cutlass/arch/mma_sparse_sm89.h" +#include "cutlass/arch/mma_sm90.h" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { +namespace detail { +/// Helper for determining whether staged accumulation should be used for a given operator +template +struct UseStagedAccumulation { + static bool const value = platform::is_same::value || + platform::is_same::value || + is_sm89_staged_policy_v; +}; +} // namespace detail +} // namespace arch +} // namespace cutlass + ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/arch/mma_sm50.h b/include/cutlass/arch/mma_sm50.h index f5458fc8be..98ff18bea0 100644 --- a/include/cutlass/arch/mma_sm50.h +++ b/include/cutlass/arch/mma_sm50.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/include/cutlass/arch/mma_sm60.h b/include/cutlass/arch/mma_sm60.h index 6fa8b6f7c9..3e3c71ef36 100644 --- a/include/cutlass/arch/mma_sm60.h +++ b/include/cutlass/arch/mma_sm60.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/include/cutlass/arch/mma_sm61.h b/include/cutlass/arch/mma_sm61.h index dc90d7868d..82a5aa7280 100644 --- a/include/cutlass/arch/mma_sm61.h +++ b/include/cutlass/arch/mma_sm61.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/include/cutlass/arch/mma_sm70.h b/include/cutlass/arch/mma_sm70.h index 4fe862f4d9..28bb46382c 100644 --- a/include/cutlass/arch/mma_sm70.h +++ b/include/cutlass/arch/mma_sm70.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -33,11 +33,7 @@ */ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "mma.h" #include "cutlass/layout/matrix.h" diff --git a/include/cutlass/arch/mma_sm75.h b/include/cutlass/arch/mma_sm75.h index 816072bd62..a39ededbe0 100644 --- a/include/cutlass/arch/mma_sm75.h +++ b/include/cutlass/arch/mma_sm75.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -34,11 +34,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/arch/wmma.h" @@ -126,7 +122,11 @@ struct Mma< : "r"(A[0]), "r"(A[1]), "r"(B[0]), "r"(C[0]), "r"(C[1])); #else - assert(0); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -188,242 +188,11 @@ struct Mma< ); #else - assert(0); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// Integer matrix multiply .8816 (8b) -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: S32 = S8 * S8 + S32 -template <> -struct Mma< - gemm::GemmShape<8, 8, 16>, - 32, - int8_t, - layout::RowMajor, - int8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<8, 8, 16>; - - using ElementA = int8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = int8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm75; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) - - unsigned const & A = reinterpret_cast(a); - unsigned const & B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.s8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - -#else - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U8 * S8 + S32 -template <> -struct Mma< - gemm::GemmShape<8, 8, 16>, - 32, - uint8_t, - layout::RowMajor, - int8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<8, 8, 16>; - - using ElementA = uint8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = int8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm75; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) - - unsigned const & A = reinterpret_cast(a); - unsigned const & B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.u8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - -#else - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = S8 * U8 + S32 -template <> -struct Mma< - gemm::GemmShape<8, 8, 16>, - 32, - int8_t, - layout::RowMajor, - uint8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<8, 8, 16>; - - using ElementA = int8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = uint8_t; - using 
LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm75; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) - - unsigned const & A = reinterpret_cast(a); - unsigned const & B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k16.row.col.s8.u8 {%0,%1}, {%2}, {%3}, {%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - - -#else - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U8 * U8 + S32 -template <> -struct Mma< - gemm::GemmShape<8, 8, 16>, - 32, - uint8_t, - layout::RowMajor, - uint8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<8, 8, 16>; - - using ElementA = uint8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = uint8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm75; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) - - unsigned const & A = reinterpret_cast(a); - unsigned const & B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k16.row.col.s32.u8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - -#else - assert(0); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -437,7 +206,7 @@ struct Mma< /// Matrix multiply-add operation: S32 = S8 * S8 + S32 template <> struct Mma< - gemm::GemmShape<8,8,16>, + gemm::GemmShape<8, 8, 16>, 32, int8_t, layout::RowMajor, @@ -447,7 +216,7 @@ struct Mma< layout::RowMajor, OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<8,8,16>; + using Shape = gemm::GemmShape<8, 8, 16>; using ElementA = int8_t; using LayoutA = layout::RowMajor; @@ -484,9 +253,12 @@ struct Mma< asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" : "=r"(D[0]), "=r"(D[1]) : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - #else - assert(0); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -494,7 +266,7 @@ struct Mma< /// Matrix multiply-add operation: S32 = U8 * S8 + S32 template <> struct Mma< - gemm::GemmShape<8,8,16>, + gemm::GemmShape<8, 8, 16>, 32, uint8_t, layout::RowMajor, @@ -504,7 +276,7 @@ struct Mma< layout::RowMajor, OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<8,8,16>; + using Shape = gemm::GemmShape<8, 8, 16>; using ElementA = uint8_t; using LayoutA = layout::RowMajor; @@ -541,9 +313,12 @@ struct Mma< asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.s8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" : "=r"(D[0]), "=r"(D[1]) : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - #else - assert(0); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + 
CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -551,7 +326,7 @@ struct Mma< /// Matrix multiply-add operation: S32 = S8 * U8 + S32 template <> struct Mma< - gemm::GemmShape<8,8,16>, + gemm::GemmShape<8, 8, 16>, 32, int8_t, layout::RowMajor, @@ -561,7 +336,7 @@ struct Mma< layout::RowMajor, OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<8,8,16>; + using Shape = gemm::GemmShape<8, 8, 16>; using ElementA = int8_t; using LayoutA = layout::RowMajor; @@ -598,9 +373,12 @@ struct Mma< asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.s8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" : "=r"(D[0]), "=r"(D[1]) : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - #else - assert(0); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -608,7 +386,7 @@ struct Mma< /// Matrix multiply-add operation: S32 = U8 * U8 + S32 template <> struct Mma< - gemm::GemmShape<8,8,16>, + gemm::GemmShape<8, 8, 16>, 32, uint8_t, layout::RowMajor, @@ -618,7 +396,7 @@ struct Mma< layout::RowMajor, OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<8,8,16>; + using Shape = gemm::GemmShape<8, 8, 16>; using ElementA = uint8_t; using LayoutA = layout::RowMajor; @@ -655,243 +433,12 @@ struct Mma< asm volatile("mma.sync.aligned.m8n8k16.row.col.satfinite.s32.u8.u8.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" : "=r"(D[0]), "=r"(D[1]) : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - -#else - assert(0); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// Integer matrix multiply (4b) -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: S32 = S4 * S4 + S32 -template <> -struct Mma< - gemm::GemmShape<8,8,32>, - 32, - int4b_t, - layout::RowMajor, - int4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<8,8,32>; - - using ElementA = int4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = int4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm75; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) - - unsigned const & A = reinterpret_cast(a); - unsigned const & B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.s4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - -#else - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U4 * S4 + S32 -template <> -struct Mma< - gemm::GemmShape<8,8,32>, - 32, - uint4b_t, - layout::RowMajor, - int4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<8,8,32>; - - using ElementA = uint4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = int4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm75; - - /// Computes multiply-add 
- CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) - - unsigned const & A = reinterpret_cast(a); - unsigned const & B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.u4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - -#else - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = S4 * U4 + S32 -template <> -struct Mma< - gemm::GemmShape<8,8,32>, - 32, - int4b_t, - layout::RowMajor, - uint4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<8,8,32>; - - using ElementA = int4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = uint4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm75; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) - - unsigned const & A = reinterpret_cast(a); - unsigned const & B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.s4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - -#else - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U4 * U4 + S32 -template <> -struct Mma< - gemm::GemmShape<8,8,32>, - 32, - uint4b_t, - layout::RowMajor, - uint4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<8,8,32>; - - using ElementA = uint4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = uint4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm75; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) - - unsigned const & A = reinterpret_cast(a); - unsigned const & B = reinterpret_cast(b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile("mma.sync.aligned.m8n8k32.row.col.s32.u4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" - : "=r"(D[0]), "=r"(D[1]) - : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - #else - assert(0); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -905,7 +452,7 @@ struct Mma< /// Matrix multiply-add operation: S32 = S4 * S4 + S32 template <> struct Mma< - gemm::GemmShape<8,8,32>, + gemm::GemmShape<8, 8, 32>, 32, int4b_t, layout::RowMajor, @@ -915,7 +462,7 @@ struct Mma< layout::RowMajor, OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<8,8,32>; + using Shape = gemm::GemmShape<8, 8, 32>; using ElementA = int4b_t; using LayoutA = layout::RowMajor; @@ -952,9 +499,12 @@ struct Mma< asm 
volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" : "=r"(D[0]), "=r"(D[1]) : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - #else - assert(0); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -962,7 +512,7 @@ struct Mma< /// Matrix multiply-add operation: S32 = U4 * S4 + S32 template <> struct Mma< - gemm::GemmShape<8,8,32>, + gemm::GemmShape<8, 8, 32>, 32, uint4b_t, layout::RowMajor, @@ -972,7 +522,7 @@ struct Mma< layout::RowMajor, OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<8,8,32>; + using Shape = gemm::GemmShape<8, 8, 32>; using ElementA = uint4b_t; using LayoutA = layout::RowMajor; @@ -1009,9 +559,12 @@ struct Mma< asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.s4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" : "=r"(D[0]), "=r"(D[1]) : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - #else - assert(0); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -1019,7 +572,7 @@ struct Mma< /// Matrix multiply-add operation: S32 = S4 * U4 + S32 template <> struct Mma< - gemm::GemmShape<8,8,32>, + gemm::GemmShape<8, 8, 32>, 32, int4b_t, layout::RowMajor, @@ -1029,7 +582,7 @@ struct Mma< layout::RowMajor, OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<8,8,32>; + using Shape = gemm::GemmShape<8, 8, 32>; using ElementA = int4b_t; using LayoutA = layout::RowMajor; @@ -1066,9 +619,12 @@ struct Mma< asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.s4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" : "=r"(D[0]), "=r"(D[1]) : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - #else - assert(0); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -1076,7 +632,7 @@ struct Mma< /// Matrix multiply-add operation: S32 = U4 * U4 + S32 template <> struct Mma< - gemm::GemmShape<8,8,32>, + gemm::GemmShape<8, 8, 32>, 32, uint4b_t, layout::RowMajor, @@ -1086,7 +642,7 @@ struct Mma< layout::RowMajor, OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<8,8,32>; + using Shape = gemm::GemmShape<8, 8, 32>; using ElementA = uint4b_t; using LayoutA = layout::RowMajor; @@ -1123,9 +679,12 @@ struct Mma< asm volatile("mma.sync.aligned.m8n8k32.row.col.satfinite.s32.u4.u4.s32 {%0,%1}, {%2}, {%3}, {%4,%5};\n" : "=r"(D[0]), "=r"(D[1]) : "r"(A), "r"(B), "r"(C[0]), "r"(C[1])); - #else - assert(0); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); #endif } }; @@ -1176,6 +735,7 @@ struct Mma< ) const { #if defined(CUTLASS_ARCH_MMA_SM75_ENABLED) + #if defined(CUTLASS_ARCH_WMMA_ENABLED) using WmmaFragmentA = nvcuda::wmma::fragment< nvcuda::wmma::matrix_a, @@ -1208,16 +768,18 @@ struct Mma< nvcuda::wmma::bmma_sync(D, A, B, C, nvcuda::wmma::experimental::bmmaBitOpXOR, nvcuda::wmma::experimental::bmmaAccumulateOpPOPC); + #else - assert(0); // WMMA must be supported to issue binary matrix multiply-accumulate instructions. + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + CUTLASS_NOT_IMPLEMENTED(); // WMMA must be supported to issue binary matrix multiply-accumulate instructions. 
#endif // defined(CUTLASS_ARCH_WMMA_ENABLED) -#else - assert(0); #endif - } }; diff --git a/include/cutlass/arch/mma_sm80.h b/include/cutlass/arch/mma_sm80.h index 5b9f524067..19d78bf20e 100644 --- a/include/cutlass/arch/mma_sm80.h +++ b/include/cutlass/arch/mma_sm80.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -34,11 +34,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/cutlass.h" #include "mma.h" @@ -53,7 +49,16 @@ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) #define CUTLASS_ARCH_MMA_SM80_ENABLED + +#if (__CUDA_ARCH__ <= 900) +#define CUTLASS_ARCH_MMA_B1_AND_SM80_ENABLED +#endif +#if (__CUDA_ARCH__ <= 890) +#define CUTLASS_ARCH_MMA_B1_XOR_SM80_ENABLED #endif + +#endif + #endif //////////////////////////////////////////////////////////////////////////////// @@ -528,7 +533,7 @@ struct Mma< //////////////////////////////////////////////////////////////////////////////// // -// Matrix Multiply 16816 - S8 input, S32 accumulation +// Matrix Multiply 16816 - S8 input, S32 accumulation - SATURATE // //////////////////////////////////////////////////////////////////////////////// @@ -543,7 +548,7 @@ struct Mma< layout::ColumnMajor, int, layout::RowMajor, - OpMultiplyAdd> { + OpMultiplyAddSaturate> { using Shape = gemm::GemmShape<16,8,16>; @@ -559,8 +564,7 @@ struct Mma< using LayoutC = layout::RowMajor; using FragmentC = Array; - using Operator = OpMultiplyAdd; - + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; /// Computes multiply-add @@ -573,6 +577,7 @@ struct Mma< ) const { #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + uint32_t const *A = reinterpret_cast(&a); uint32_t const &B = reinterpret_cast(b); @@ -580,8 +585,8 @@ struct Mma< int *D = reinterpret_cast(&d); asm volatile( - "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32 {%0,%1,%2,%3}, {%4,%5}, {%6}, " - "{%7,%8,%9,%10};\n" + "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " + "{%6}, {%7,%8,%9,%10};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); @@ -603,7 +608,7 @@ struct Mma< layout::ColumnMajor, int, layout::RowMajor, - OpMultiplyAdd> { + OpMultiplyAddSaturate> { using Shape = gemm::GemmShape<16,8,16>; @@ -619,7 +624,7 @@ struct Mma< using LayoutC = layout::RowMajor; using FragmentC = Array; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; /// Computes multiply-add @@ -632,6 +637,7 @@ struct Mma< ) const { #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + uint32_t const *A = reinterpret_cast(&a); uint32_t const &B = reinterpret_cast(b); @@ -639,8 +645,8 @@ struct Mma< int *D = reinterpret_cast(&d); asm volatile( - "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32 {%0,%1,%2,%3}, {%4,%5}, {%6}, " - "{%7,%8,%9,%10};\n" + "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " + "{%6}, {%7,%8,%9,%10};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); @@ -662,7 +668,7 @@ struct Mma< layout::ColumnMajor, int, layout::RowMajor, - OpMultiplyAdd> { 
+ OpMultiplyAddSaturate> { using Shape = gemm::GemmShape<16,8,16>; @@ -678,7 +684,7 @@ struct Mma< using LayoutC = layout::RowMajor; using FragmentC = Array; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; /// Computes multiply-add @@ -699,12 +705,12 @@ struct Mma< int *D = reinterpret_cast(&d); asm volatile( - "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32 {%0,%1,%2,%3}, {%4,%5}, {%6}, " - "{%7,%8,%9,%10};\n" + "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " + "{%6}, {%7,%8,%9,%10};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - + #else assert(0); #endif @@ -722,7 +728,7 @@ struct Mma< layout::ColumnMajor, int, layout::RowMajor, - OpMultiplyAdd> { + OpMultiplyAddSaturate> { using Shape = gemm::GemmShape<16,8,16>; @@ -738,7 +744,7 @@ struct Mma< using LayoutC = layout::RowMajor; using FragmentC = Array; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; /// Computes multiply-add @@ -759,13 +765,12 @@ struct Mma< int *D = reinterpret_cast(&d); asm volatile( - "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32 {%0,%1,%2,%3}, {%4,%5}, {%6}, " - "{%7,%8,%9,%10};\n" + "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " + "{%6}, {%7,%8,%9,%10};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - #else assert(0); #endif @@ -774,14 +779,14 @@ struct Mma< //////////////////////////////////////////////////////////////////////////////// // -// Matrix Multiply 16816 - S8 input, S32 accumulation - SATURATE +// Matrix Multiply 16832 - S8 input, S32 accumulation - SATURATE // //////////////////////////////////////////////////////////////////////////////// /// Matrix multiply-add operation: S32 = S8 * S8 + S32 template <> struct Mma< - gemm::GemmShape<16,8,16>, + gemm::GemmShape<16,8,32>, 32, int8_t, layout::RowMajor, @@ -791,15 +796,15 @@ struct Mma< layout::RowMajor, OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<16,8,16>; + using Shape = gemm::GemmShape<16,8,32>; using ElementA = int8_t; using LayoutA = layout::RowMajor; - using FragmentA = Array; + using FragmentA = Array; using ElementB = int8_t; using LayoutB = layout::ColumnMajor; - using FragmentB = Array; + using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; @@ -819,18 +824,18 @@ struct Mma< #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - uint32_t const *A = reinterpret_cast(&a); - uint32_t const &B = reinterpret_cast(b); + uint32_t const * A = reinterpret_cast(&a); + uint32_t const * B = reinterpret_cast(&b); - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); - asm volatile( - "mma.sync.aligned.m16n8k16.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " - "{%6}, {%7,%8,%9,%10};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), - "r"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else assert(0); @@ -841,7 +846,7 @@ struct Mma< /// Matrix 
multiply-add operation: S32 = U8 * S8 + S32 template <> struct Mma< - gemm::GemmShape<16,8,16>, + gemm::GemmShape<16,8,32>, 32, uint8_t, layout::RowMajor, @@ -851,15 +856,15 @@ struct Mma< layout::RowMajor, OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<16,8,16>; + using Shape = gemm::GemmShape<16,8,32>; using ElementA = uint8_t; using LayoutA = layout::RowMajor; - using FragmentA = Array; + using FragmentA = Array; using ElementB = int8_t; using LayoutB = layout::ColumnMajor; - using FragmentB = Array; + using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; @@ -880,17 +885,17 @@ struct Mma< #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); - uint32_t const &B = reinterpret_cast(b); + uint32_t const *B = reinterpret_cast(&b); int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); asm volatile( - "mma.sync.aligned.m16n8k16.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " - "{%6}, {%7,%8,%9,%10};\n" + "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), - "r"(C[3])); + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else assert(0); @@ -901,7 +906,7 @@ struct Mma< /// Matrix multiply-add operation: S32 = S8 * U8 + S32 template <> struct Mma< - gemm::GemmShape<16,8,16>, + gemm::GemmShape<16,8,32>, 32, int8_t, layout::RowMajor, @@ -911,15 +916,15 @@ struct Mma< layout::RowMajor, OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<16,8,16>; + using Shape = gemm::GemmShape<16,8,32>; using ElementA = int8_t; using LayoutA = layout::RowMajor; - using FragmentA = Array; + using FragmentA = Array; using ElementB = uint8_t; using LayoutB = layout::ColumnMajor; - using FragmentB = Array; + using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; @@ -940,18 +945,18 @@ struct Mma< #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); - uint32_t const &B = reinterpret_cast(b); + uint32_t const *B = reinterpret_cast(&b); int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); asm volatile( - "mma.sync.aligned.m16n8k16.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " - "{%6}, {%7,%8,%9,%10};\n" + "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), - "r"(C[3])); - + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + #else assert(0); #endif @@ -961,7 +966,7 @@ struct Mma< /// Matrix multiply-add operation: S32 = U8 * U8 + S32 template <> struct Mma< - gemm::GemmShape<16,8,16>, + gemm::GemmShape<16,8,32>, 32, uint8_t, layout::RowMajor, @@ -971,15 +976,15 @@ struct Mma< layout::RowMajor, OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<16,8,16>; + using Shape = gemm::GemmShape<16,8,32>; using ElementA = uint8_t; using LayoutA = layout::RowMajor; - using FragmentA = Array; + using FragmentA = Array; using ElementB = uint8_t; using LayoutB = layout::ColumnMajor; - using FragmentB = Array; + using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; @@ -1000,17 +1005,17 @@ struct Mma< #if 
defined(CUTLASS_ARCH_MMA_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); - uint32_t const &B = reinterpret_cast(b); + uint32_t const *B = reinterpret_cast(&b); int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); asm volatile( - "mma.sync.aligned.m16n8k16.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5}, " - "{%6}, {%7,%8,%9,%10};\n" + "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(B), "r"(C[0]), "r"(C[1]), "r"(C[2]), - "r"(C[3])); + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else assert(0); @@ -1020,38 +1025,38 @@ struct Mma< //////////////////////////////////////////////////////////////////////////////// // -// Matrix Multiply 16832 - S8 input, S32 accumulation +// Matrix Multiply 16864 - S4 input, S32 accumulation - SATURATE // //////////////////////////////////////////////////////////////////////////////// -/// Matrix multiply-add operation: S32 = S8 * S8 + S32 +/// Matrix multiply-add operation: S32 = S4 * S4 + S32 template <> struct Mma< - gemm::GemmShape<16,8,32>, + gemm::GemmShape<16, 8, 64>, 32, - int8_t, + cutlass::int4b_t, layout::RowMajor, - int8_t, + cutlass::int4b_t, layout::ColumnMajor, int, layout::RowMajor, - OpMultiplyAdd> { + OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<16,8,32>; + using Shape = gemm::GemmShape<16, 8, 64>; - using ElementA = int8_t; + using ElementA = cutlass::int4b_t; using LayoutA = layout::RowMajor; - using FragmentA = Array; + using FragmentA = Array; - using ElementB = int8_t; + using ElementB = cutlass::int4b_t; using LayoutB = layout::ColumnMajor; - using FragmentB = Array; + using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; using FragmentC = Array; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; /// Computes multiply-add @@ -1065,53 +1070,57 @@ struct Mma< #if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); + uint32_t const * A = reinterpret_cast(&a); + uint32_t const * B = reinterpret_cast(&b); - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); + int const *C = reinterpret_cast(&c); + int *D = reinterpret_cast(&d); - asm volatile( - "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); + asm volatile( + "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32.satfinite {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } }; -/// Matrix multiply-add operation: S32 = U8 * S8 + S32 +/// Matrix multiply-add operation: S32 = U4 * S4 + S32 template <> struct Mma< - gemm::GemmShape<16,8,32>, + gemm::GemmShape<16, 8, 64>, 32, - uint8_t, + cutlass::uint4b_t, layout::RowMajor, - int8_t, + cutlass::int4b_t, layout::ColumnMajor, int, layout::RowMajor, - OpMultiplyAdd> { + OpMultiplyAddSaturate> { - using 
Shape = gemm::GemmShape<16,8,32>; + using Shape = gemm::GemmShape<16, 8, 64>; - using ElementA = uint8_t; + using ElementA = cutlass::uint4b_t; using LayoutA = layout::RowMajor; - using FragmentA = Array; + using FragmentA = Array; - using ElementB = int8_t; + using ElementB = cutlass::int4b_t; using LayoutB = layout::ColumnMajor; - using FragmentB = Array; + using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; using FragmentC = Array; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; /// Computes multiply-add @@ -1132,46 +1141,50 @@ struct Mma< int *D = reinterpret_cast(&d); asm volatile( - "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9}, {%10,%11,%12,%13};\n" + "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32.satfinite {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } }; -/// Matrix multiply-add operation: S32 = S8 * U8 + S32 +/// Matrix multiply-add operation: S32 = S4 * U4 + S32 template <> struct Mma< - gemm::GemmShape<16,8,32>, + gemm::GemmShape<16, 8, 64>, 32, - int8_t, + cutlass::int4b_t, layout::RowMajor, - uint8_t, + cutlass::uint4b_t, layout::ColumnMajor, int, layout::RowMajor, - OpMultiplyAdd> { + OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<16,8,32>; + using Shape = gemm::GemmShape<16, 8, 64>; - using ElementA = int8_t; + using ElementA = cutlass::int4b_t; using LayoutA = layout::RowMajor; - using FragmentA = Array; + using FragmentA = Array; - using ElementB = uint8_t; + using ElementB = cutlass::uint4b_t; using LayoutB = layout::ColumnMajor; - using FragmentB = Array; + using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; using FragmentC = Array; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; /// Computes multiply-add @@ -1192,46 +1205,50 @@ struct Mma< int *D = reinterpret_cast(&d); asm volatile( - "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9}, {%10,%11,%12,%13};\n" + "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32.satfinite {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } }; -/// Matrix multiply-add operation: S32 = U8 * U8 + S32 +/// Matrix multiply-add operation: S32 = U4 * U4 + S32 template <> struct Mma< - gemm::GemmShape<16,8,32>, + gemm::GemmShape<16, 8, 64>, 32, - uint8_t, + cutlass::uint4b_t, layout::RowMajor, - uint8_t, + cutlass::uint4b_t, layout::ColumnMajor, int, layout::RowMajor, - OpMultiplyAdd> { + OpMultiplyAddSaturate> { - using Shape = gemm::GemmShape<16,8,32>; + using Shape = gemm::GemmShape<16, 8, 64>; - using ElementA = uint8_t; + using ElementA = cutlass::uint4b_t; using LayoutA = layout::RowMajor; - using FragmentA = Array; + using FragmentA = Array; - using ElementB = uint8_t; + using ElementB = cutlass::uint4b_t; using LayoutB = layout::ColumnMajor; - using FragmentB = Array; + using FragmentB = Array; using ElementC 
= int; using LayoutC = layout::RowMajor; using FragmentC = Array; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; /// Computes multiply-add @@ -1252,725 +1269,56 @@ struct Mma< int *D = reinterpret_cast(&d); asm volatile( - "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - -#else - assert(0); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// Matrix Multiply 16832 - S8 input, S32 accumulation - SATURATE -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: S32 = S8 * S8 + S32 -template <> -struct Mma< - gemm::GemmShape<16,8,32>, - 32, - int8_t, - layout::RowMajor, - int8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<16,8,32>; - - using ElementA = int8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = int8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const * A = reinterpret_cast(&a); - uint32_t const * B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k32.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, " - "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - -#else - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U8 * S8 + S32 -template <> -struct Mma< - gemm::GemmShape<16,8,32>, - 32, - uint8_t, - layout::RowMajor, - int8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<16,8,32>; - - using ElementA = uint8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = int8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k32.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, " - "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - -#else - assert(0); -#endif - } -}; - -/// 
Matrix multiply-add operation: S32 = S8 * U8 + S32 -template <> -struct Mma< - gemm::GemmShape<16,8,32>, - 32, - int8_t, - layout::RowMajor, - uint8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<16,8,32>; - - using ElementA = int8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = uint8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k32.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, " - "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - -#else - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U8 * U8 + S32 -template <> -struct Mma< - gemm::GemmShape<16,8,32>, - 32, - uint8_t, - layout::RowMajor, - uint8_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<16,8,32>; - - using ElementA = uint8_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = uint8_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k32.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, " - "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - -#else - assert(0); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// Matrix Multiply 16864 - S4 input, S32 accumulation -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: S32 = S4 * S4 + S32 -template <> -struct Mma< - gemm::GemmShape<16, 8, 64>, - 32, - cutlass::int4b_t, - layout::RowMajor, - cutlass::int4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<16, 8, 64>; - - using ElementA = cutlass::int4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::int4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; 
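// Illustrative sketch, not part of this patch: what the .satfinite qualifier used by the
// specializations in this hunk changes. A saturating S32 multiply-accumulate clamps to
// [INT32_MIN, INT32_MAX] instead of wrapping on overflow. A host-side analogue, assuming
// the intermediate sum is formed at wider precision:
#include <algorithm>
#include <cstdint>

int32_t multiply_add_satfinite(int32_t acc, int8_t a, int8_t b) {
  int64_t wide = static_cast<int64_t>(acc) +
                 static_cast<int64_t>(a) * static_cast<int64_t>(b);
  wide = std::min<int64_t>(std::max<int64_t>(wide, INT32_MIN), INT32_MAX);
  return static_cast<int32_t>(wide);  // e.g. acc = INT32_MAX with a = b = 127 stays at INT32_MAX
}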
- using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - -#else - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U4 * S4 + S32 -template <> -struct Mma< - gemm::GemmShape<16, 8, 64>, - 32, - cutlass::uint4b_t, - layout::RowMajor, - cutlass::int4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<16, 8, 64>; - - using ElementA = cutlass::uint4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::int4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - -#else - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = S4 * U4 + S32 -template <> -struct Mma< - gemm::GemmShape<16, 8, 64>, - 32, - cutlass::int4b_t, - layout::RowMajor, - cutlass::uint4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<16, 8, 64>; - - using ElementA = cutlass::int4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::uint4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - -#else - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U4 * U4 + S32 
-template <> -struct Mma< - gemm::GemmShape<16, 8, 64>, - 32, - cutlass::uint4b_t, - layout::RowMajor, - cutlass::uint4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd> { - - using Shape = gemm::GemmShape<16, 8, 64>; - - using ElementA = cutlass::uint4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::uint4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9}, {%10,%11,%12,%13};\n" + "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32.satfinite {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } }; - //////////////////////////////////////////////////////////////////////////////// // -// Matrix Multiply 16864 - S4 input, S32 accumulation - SATURATE +// Matrix Multiply 168256 - B1 input, S32 accumulation - AND,POPC // //////////////////////////////////////////////////////////////////////////////// -/// Matrix multiply-add operation: S32 = S4 * S4 + S32 -template <> -struct Mma< - gemm::GemmShape<16, 8, 64>, - 32, - cutlass::int4b_t, - layout::RowMajor, - cutlass::int4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<16, 8, 64>; - - using ElementA = cutlass::int4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::int4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const * A = reinterpret_cast(&a); - uint32_t const * B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k64.row.col.s32.s4.s4.s32.satfinite {%0,%1,%2,%3}, " - "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - -#else - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U4 * S4 + S32 -template <> -struct Mma< - gemm::GemmShape<16, 8, 64>, - 32, - cutlass::uint4b_t, - layout::RowMajor, - cutlass::int4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<16, 8, 64>; - - using ElementA = cutlass::uint4b_t; - 
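// Illustrative sketch, not part of this patch: every warp-wide mma.sync specialization in
// this file sizes its per-thread fragments as (operand tile elements) / 32. Spot-checking
// the m16n8k64 INT4 saturating atom that supersedes the block being removed here:
static_assert(cutlass::arch::Mma<
                  cutlass::gemm::GemmShape<16, 8, 64>, 32,
                  cutlass::int4b_t, cutlass::layout::RowMajor,
                  cutlass::int4b_t, cutlass::layout::ColumnMajor,
                  int, cutlass::layout::RowMajor,
                  cutlass::arch::OpMultiplyAddSaturate>::FragmentA::kElements == (16 * 64) / 32,
              "each thread holds 32 packed int4b values of A");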
using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::int4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAddSaturate; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k64.row.col.s32.u4.s4.s32.satfinite {%0,%1,%2,%3}, " - "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - -#else - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = S4 * U4 + S32 -template <> -struct Mma< - gemm::GemmShape<16, 8, 64>, - 32, - cutlass::int4b_t, - layout::RowMajor, - cutlass::uint4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate> { - - using Shape = gemm::GemmShape<16, 8, 64>; - - using ElementA = cutlass::int4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::uint4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c - ) const { - -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - asm volatile( - "mma.sync.aligned.m16n8k64.row.col.s32.s4.u4.s32.satfinite {%0,%1,%2,%3}, " - "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); - -#else - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U4 * U4 + S32 +/// Matrix multiply-add operation: S32 = B1 & B1 + S32 template <> struct Mma< - gemm::GemmShape<16, 8, 64>, + gemm::GemmShape<16,8,256>, 32, - cutlass::uint4b_t, + cutlass::uint1b_t, layout::RowMajor, - cutlass::uint4b_t, + cutlass::uint1b_t, layout::ColumnMajor, - int, + int32_t, layout::RowMajor, - OpMultiplyAddSaturate> { + OpAndPopc> { - using Shape = gemm::GemmShape<16, 8, 64>; + using Shape = gemm::GemmShape<16,8,256>; - using ElementA = cutlass::uint4b_t; + using ElementA = cutlass::uint1b_t; using LayoutA = layout::RowMajor; - using FragmentA = Array; + using FragmentA = Array; - using ElementB = cutlass::uint4b_t; + using ElementB = cutlass::uint1b_t; using LayoutB = layout::ColumnMajor; - using FragmentB = Array; + using FragmentB = Array; - using ElementC = int; + using ElementC = int32_t; using LayoutC = layout::RowMajor; - using FragmentC = Array; + using FragmentC = Array; - using Operator = OpMultiplyAddSaturate; + using Operator = OpAndPopc; using ArchTag = arch::Sm80; /// Computes multiply-add @@ -1982,7 +1330,7 @@ 
struct Mma< FragmentC const &c ) const { -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) +#if defined(CUTLASS_ARCH_MMA_B1_AND_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const *B = reinterpret_cast(&b); @@ -1991,13 +1339,18 @@ struct Mma< int *D = reinterpret_cast(&d); asm volatile( - "mma.sync.aligned.m16n8k64.row.col.s32.u4.u4.s32.satfinite {%0,%1,%2,%3}, " - "{%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.and.popc {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, " + "{%8,%9}, {%10,%11,%12,%13};\n" : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -2042,7 +1395,7 @@ struct Mma< FragmentC const &c ) const { -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) +#if defined(CUTLASS_ARCH_MMA_B1_AND_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const *B = reinterpret_cast(&b); @@ -2059,6 +1412,10 @@ struct Mma< "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3])); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -2109,13 +1466,14 @@ struct Mma< FragmentC const &c ) const { -#if defined(CUTLASS_ARCH_MMA_SM80_ENABLED) +#if defined(CUTLASS_ARCH_MMA_B1_XOR_SM80_ENABLED) uint32_t const *A = reinterpret_cast(&a); uint32_t const *B = reinterpret_cast(&b); int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); + asm volatile( "mma.sync.aligned.m16n8k256.row.col.s32.b1.b1.s32.xor.popc {%0,%1,%2,%3}, " "{%4,%5,%6,%7}, " @@ -2126,9 +1484,13 @@ struct Mma< #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); - -#endif // defined(CUTLASS_ARCH_MMA_SM80_ENABLED) + +#endif // defined(CUTLASS_ARCH_MMA_B1_XOR_SM80_ENABLED) } }; @@ -2136,5 +1498,4 @@ struct Mma< } // namespace arch } // namespace cutlass - ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/arch/mma_sm89.h b/include/cutlass/arch/mma_sm89.h new file mode 100644 index 0000000000..d8a75b6623 --- /dev/null +++ b/include/cutlass/arch/mma_sm89.h @@ -0,0 +1,363 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Matrix multiply-accumulate specialzied for SM89 +*/ + +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "mma.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4) + +# define CUTLASS_ARCH_MMA_SM89_SUPPORTED 1 +#endif + +#if defined(CUTLASS_ARCH_MMA_SM89_SUPPORTED) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 890) +# define CUTLASS_ARCH_MMA_SM89_ENABLED +#endif + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +//////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// Whether the Mma uses as SM89 staged accumulation policy +template +static constexpr bool is_sm89_staged_policy_v = + ( + // ElementA must be FP8 + platform::is_same::value || + platform::is_same::value + ) && + ( + // ElementB must be FP8 + platform::is_same::value || + platform::is_same::value + ) && + ( + // The instruction shape must be 16x8x32 + Operator::ArchMmaOperator::Shape::kM == 16 && + Operator::ArchMmaOperator::Shape::kN == 8 && + Operator::ArchMmaOperator::Shape::kK == 32 + ) && + ( + // The operator must be OpMultiplyAdd (default) + platform::is_same::value + ); +} // namespace detail + +//////////////////////////////////////////////////////////////////////////////// + +//////////////////////////////////////////////////////////////////////////////// +// +// Matrix Multiply 16832 - Float {E4M3, E5M2}, FP32 accumulation +// +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation - F32 = fe4m3 * fe4m3 + F32 +template +struct Mma< + gemm::GemmShape<16, 8, 32>, + 32, + cutlass::float_e4m3_t, + layout::RowMajor, + cutlass::float_e4m3_t, + layout::ColumnMajor, + float, + layout::RowMajor, + Operator_> { + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16, 8, 32>; + + using ElementA = cutlass::float_e4m3_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e4m3_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + 
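// Illustrative sketch, not part of this patch: driving the SM89 FP8 atom declared above from
// thread-level code. Fragment widths follow the specialization (16 e4m3 values of A, 8 of B,
// 4 F32 accumulators per thread). The function name is hypothetical; it assumes compilation
// for sm_89 with CUDA 12.4+ (the guards above) and fragments already distributed across the
// warp in the layout mma.sync expects.
__device__ void fp8_mma_sketch(
    cutlass::Array<cutlass::float_e4m3_t, 16> const &a,
    cutlass::Array<cutlass::float_e4m3_t, 8> const &b,
    cutlass::Array<float, 4> &acc) {
  using MmaOp = cutlass::arch::Mma<
      cutlass::gemm::GemmShape<16, 8, 32>, 32,
      cutlass::float_e4m3_t, cutlass::layout::RowMajor,
      cutlass::float_e4m3_t, cutlass::layout::ColumnMajor,
      float, cutlass::layout::RowMajor,
      cutlass::arch::OpMultiplyAdd>;
  MmaOp mma;
  mma(acc, a, b, acc);  // d = a * b + c with the accumulator reused as c
}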
+ asm( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : + "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) + ); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +/// Matrix multiply-add operation - F32 = fe4m3 * fe5m2 + F32 +template +struct Mma< + gemm::GemmShape<16, 8, 32>, + 32, + cutlass::float_e4m3_t, + layout::RowMajor, + cutlass::float_e5m2_t, + layout::ColumnMajor, + float, + layout::RowMajor, + Operator_> { + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16, 8, 32>; + + using ElementA = cutlass::float_e4m3_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e5m2_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm( + "mma.sync.aligned.m16n8k32.row.col.f32.e4m3.e5m2.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : + "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) + ); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +/// Matrix multiply-add operation - F32 = fe5m2 * fe4m3 + F32 +template +struct Mma< + gemm::GemmShape<16, 8, 32>, + 32, + cutlass::float_e5m2_t, + layout::RowMajor, + cutlass::float_e4m3_t, + layout::ColumnMajor, + float, + layout::RowMajor, + Operator_> { + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16, 8, 32>; + + using ElementA = cutlass::float_e5m2_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e4m3_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm( + "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e4m3.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : + "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) + ); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + 
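// Illustrative sketch, not part of this patch: the two 8-bit float formats these mixed
// specializations combine. e4m3 favors precision (3 mantissa bits, max finite 448);
// e5m2 favors range (2 mantissa bits, max finite 57344). Under round-to-nearest the
// coarser spacing of e5m2 shows up immediately:
#include "cutlass/numeric_types.h"

float roundtrip_e4m3(float x) { return float(cutlass::float_e4m3_t(x)); }  // 1.09f -> 1.125f
float roundtrip_e5m2(float x) { return float(cutlass::float_e5m2_t(x)); }  // 1.09f -> 1.0f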
CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +/// Matrix multiply-add operation - F32 = fe5m2 * fe5m2 + F32 +template +struct Mma< + gemm::GemmShape<16, 8, 32>, + 32, + cutlass::float_e5m2_t, + layout::RowMajor, + cutlass::float_e5m2_t, + layout::ColumnMajor, + float, + layout::RowMajor, + Operator_> { + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16, 8, 32>; + + using ElementA = cutlass::float_e5m2_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e5m2_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + asm( + "mma.sync.aligned.m16n8k32.row.col.f32.e5m2.e5m2.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9}, {%10,%11,%12,%13};\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : + "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), + "r"(B[0]), "r"(B[1]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]) + ); + +#else + + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); + +#endif + } +}; + +} // namespace arch +} // namespace cutlass diff --git a/include/cutlass/arch/mma_sm90.h b/include/cutlass/arch/mma_sm90.h new file mode 100644 index 0000000000..16108f0a1e --- /dev/null +++ b/include/cutlass/arch/mma_sm90.h @@ -0,0 +1,241 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Matrix multiply +*/ + +#pragma once + +#include + +#include "mma.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/config.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +//////////////////////////////////////////////////////////////////////////////// +/// Matrix Multiply-Add 16x8x4 fp64 +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F64 = F64 * F64 + F64 +template <> +struct Mma< + gemm::GemmShape<16,8,4>, + 32, + double, + layout::RowMajor, + double, + layout::ColumnMajor, + double, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16,8,4>; + + using ElementA = double; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = double; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = double; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + + using ArchTag = arch::Sm90; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) + + double const *A = reinterpret_cast(&a); + double const *B = reinterpret_cast(&b); + + double const *C = reinterpret_cast(&c); + double *D = reinterpret_cast(&d); + + asm volatile("mma.sync.aligned.m16n8k4.row.col.f64.f64.f64.f64.rn {%0, %1, %2, %3}, {%4, %5}, {%6}, {%7, %8, %9, %10};\n" + : "=d"(D[0]), "=d"(D[1]), "=d"(D[2]), "=d"(D[3]) + : "d"(A[0]), "d"(A[1]), + "d"(B[0]), + "d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3])); + +#else + CUTLASS_UNUSED(d); + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_NOT_IMPLEMENTED(); +#endif + } +}; + +//////////////////////////////////////////////////////////////////////////////// +/// Matrix Multiply-Add 16x8x8 fp64 +//////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F64 = F64 * F64 + F64 +template <> +struct Mma< + gemm::GemmShape<16,8,8>, + 32, + double, + layout::RowMajor, + double, + layout::ColumnMajor, + double, + layout::RowMajor, + OpMultiplyAdd> { + + using Shape = gemm::GemmShape<16,8,8>; + + using ElementA = double; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = double; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = double; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using Operator = OpMultiplyAdd; + + using ArchTag = arch::Sm90; + + CUTLASS_HOST_DEVICE + void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b, + FragmentC const &c) const { + +#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED) + + double 
const *A = reinterpret_cast<double const *>(&a);
+    double const *B = reinterpret_cast<double const *>(&b);
+
+    double const *C = reinterpret_cast<double const *>(&c);
+    double *D = reinterpret_cast<double *>(&d);
+
+    asm volatile("mma.sync.aligned.m16n8k8.row.col.f64.f64.f64.f64 {%0, %1, %2, %3}, {%4, %5, %6, %7}, {%8, %9}, {%10, %11, %12, %13};\n"
+      : "=d"(D[0]), "=d"(D[1]), "=d"(D[2]), "=d"(D[3])
+      : "d"(A[0]), "d"(A[1]), "d"(A[2]), "d"(A[3]),
+        "d"(B[0]), "d"(B[1]),
+        "d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3]));
+
+#else
+
+    CUTLASS_UNUSED(d);
+    CUTLASS_UNUSED(a);
+    CUTLASS_UNUSED(b);
+    CUTLASS_UNUSED(c);
+    CUTLASS_NOT_IMPLEMENTED();
+#endif
+  }
+};
+
+////////////////////////////////////////////////////////////////////////////////
+/// Matrix Multiply-Add 16x8x16 fp64
+////////////////////////////////////////////////////////////////////////////////
+
+/// Matrix multiply-add operation: F64 = F64 * F64 + F64
+template <>
+struct Mma<
+  gemm::GemmShape<16,8,16>,
+  32,
+  double,
+  layout::RowMajor,
+  double,
+  layout::ColumnMajor,
+  double,
+  layout::RowMajor,
+  OpMultiplyAdd> {
+
+  using Shape = gemm::GemmShape<16,8,16>;
+
+  using ElementA = double;
+  using LayoutA = layout::RowMajor;
+  using FragmentA = Array<double, 8>;
+
+  using ElementB = double;
+  using LayoutB = layout::ColumnMajor;
+  using FragmentB = Array<double, 4>;
+
+  using ElementC = double;
+  using LayoutC = layout::RowMajor;
+  using FragmentC = Array<double, 4>;
+
+  using Operator = OpMultiplyAdd;
+
+  using ArchTag = arch::Sm90;
+
+  CUTLASS_HOST_DEVICE
+  void operator()(FragmentC &d, FragmentA const &a, FragmentB const &b,
+                  FragmentC const &c) const {
+
+#if defined(CUTLASS_ARCH_MMA_SM90_F64_MMA_ENABLED)
+
+    double const *A = reinterpret_cast<double const *>(&a);
+    double const *B = reinterpret_cast<double const *>(&b);
+
+    double const *C = reinterpret_cast<double const *>(&c);
+    double *D = reinterpret_cast<double *>(&d);
+
+    asm volatile("mma.sync.aligned.m16n8k16.row.col.f64.f64.f64.f64 {%0, %1, %2, %3}, {%4, %5, %6, %7, %8, %9, %10, %11}, {%12, %13, %14, %15}, {%16, %17, %18, %19};\n"
+      : "=d"(D[0]), "=d"(D[1]), "=d"(D[2]), "=d"(D[3])
+      : "d"(A[0]), "d"(A[1]), "d"(A[2]), "d"(A[3]), "d"(A[4]), "d"(A[5]), "d"(A[6]), "d"(A[7]),
+        "d"(B[0]), "d"(B[1]), "d"(B[2]), "d"(B[3]),
+        "d"(C[0]), "d"(C[1]), "d"(C[2]), "d"(C[3]));
+
+#else
+    CUTLASS_NOT_IMPLEMENTED();
+#endif
+  }
+};
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+} // namespace arch
+} // namespace cutlass
+
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
diff --git a/include/cutlass/arch/mma_sparse_sm80.h b/include/cutlass/arch/mma_sparse_sm80.h
index e22d600623..ed2a5ad019 100644
--- a/include/cutlass/arch/mma_sparse_sm80.h
+++ b/include/cutlass/arch/mma_sparse_sm80.h
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -35,11 +35,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "mma.h" #include "cutlass/layout/matrix.h" @@ -54,6 +50,7 @@ #if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)) #define CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED #endif + #endif ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -121,6 +118,27 @@ struct SparseMma< uint32_t const *C = reinterpret_cast(&c); uint32_t *D = reinterpret_cast(&d); +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {%0,%1}, " + "{%2,%3,%4,%5}, {%6,%7,%8,%9}, {%10,%11}, %12, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(B[2]), "r"(B[3]), "r"(C[0]), "r"(C[1]), "r"(E)); + } + else if (id2 == 1) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {%0,%1}, " + "{%2,%3,%4,%5}, {%6,%7,%8,%9}, {%10,%11}, %12, 0x1;\n" + : "=r"(D[0]), "=r"(D[1]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(B[2]), "r"(B[3]), "r"(C[0]), "r"(C[1]), "r"(E)); + } + else { + assert(0); + } +#else if (id2 == 0) { asm volatile( "mma.sp.sync.aligned.m16n8k32.row.col.f16.f16.f16.f16 {%0,%1}, " @@ -140,7 +158,13 @@ struct SparseMma< else { assert(0); } +#endif + #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -200,6 +224,29 @@ struct SparseMma< float const *C = reinterpret_cast(&c); float *D = reinterpret_cast(&d); +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(B[2]), "r"(B[3]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), + "r"(E)); + } + else if (id2 == 1) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, " + "{%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), + "r"(B[2]), "r"(B[3]), "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), + "r"(E)); + } + else { + assert(0); + } +#else if (id2 == 0) { asm volatile( "mma.sp.sync.aligned.m16n8k32.row.col.f32.f16.f16.f32 {%0,%1,%2,%3}, " @@ -222,8 +269,13 @@ struct SparseMma< assert(0); } -#else +#endif +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -276,26 +328,50 @@ struct SparseMma, 32, bfloat16_t, layout::RowMajor, float const *C = reinterpret_cast(&c); float *D = reinterpret_cast(&d); +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) if (id2 == 0) { - asm volatile( - "mma.sp.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), 
- "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); } else if (id2 == 1) { - asm volatile( - "mma.sp.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); } else { - assert(0); + assert(0); + } +#else + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } else if (id2 == 1) { + asm volatile( + "mma.sp.sync.aligned.m16n8k32.row.col.f32.bf16.bf16.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } else { + assert(0); } +#endif #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -348,26 +424,50 @@ struct SparseMma, 32, tfloat32_t, layout::RowMajor, float const *C = reinterpret_cast(&c); float *D = reinterpret_cast(&d); +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) if (id2 == 0) { - asm volatile( - "mma.sp.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); } else if (id2 == 1) { - asm volatile( - "mma.sp.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 " - "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" - : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + asm volatile( + 
"mma.sp::ordered_metadata.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); } else { - assert(0); + assert(0); + } +#else + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } else if (id2 == 1) { + asm volatile( + "mma.sp.sync.aligned.m16n8k16.row.col.f32.tf32.tf32.f32 " + "{%0,%1,%2,%3}, {%4,%5,%6,%7}, {%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x1;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } else { + assert(0); } +#endif #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -375,7 +475,7 @@ struct SparseMma, 32, tfloat32_t, layout::RowMajor, //////////////////////////////////////////////////////////////////////////////// // -// Sparse Matrix Multiply 16864 - S8 input, S32 accumulation +// Sparse Matrix Multiply 16864 - S8 input, S32 accumulation - SATURATE // //////////////////////////////////////////////////////////////////////////////// @@ -390,7 +490,7 @@ struct SparseMma< layout::ColumnMajor, int, layout::RowMajor, - OpMultiplyAdd, + OpMultiplyAddSaturate, SPFormatType::Thread> { using Shape = gemm::GemmShape<16,8,64>; @@ -409,7 +509,7 @@ struct SparseMma< using FragmentE = uint32_t; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; static int const kSparse = 2; @@ -437,18 +537,35 @@ struct SparseMma< int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.s8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); - +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } #else + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#endif +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); 
+ CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -465,7 +582,7 @@ struct SparseMma< layout::ColumnMajor, int, layout::RowMajor, - OpMultiplyAdd, + OpMultiplyAddSaturate, SPFormatType::Thread> { using Shape = gemm::GemmShape<16,8,64>; @@ -484,7 +601,7 @@ struct SparseMma< using FragmentE = uint32_t; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; static int const kSparse = 2; @@ -512,18 +629,36 @@ struct SparseMma< int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.u8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#else + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#endif #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -540,7 +675,7 @@ struct SparseMma< layout::ColumnMajor, int, layout::RowMajor, - OpMultiplyAdd, + OpMultiplyAddSaturate, SPFormatType::Thread> { using Shape = gemm::GemmShape<16,8,64>; @@ -559,7 +694,7 @@ struct SparseMma< using FragmentE = uint32_t; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; static int const kSparse = 2; @@ -587,18 +722,35 @@ struct SparseMma< int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.s8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); - +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } #else + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + 
"{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#endif +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -615,7 +767,7 @@ struct SparseMma< layout::ColumnMajor, int, layout::RowMajor, - OpMultiplyAdd, + OpMultiplyAddSaturate, SPFormatType::Thread> { using Shape = gemm::GemmShape<16,8,64>; @@ -634,7 +786,7 @@ struct SparseMma< using FragmentE = uint32_t; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; static int const kSparse = 2; @@ -662,18 +814,35 @@ struct SparseMma< int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.u8.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); - +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k64.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } #else + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#endif +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } @@ -681,33 +850,33 @@ struct SparseMma< //////////////////////////////////////////////////////////////////////////////// // -// Sparse Matrix Multiply 16864 - S8 input, S32 accumulation - SATURATE +// Sparse Matrix Multiply 168128 - S4 input, S32 accumulation - SATURATE // //////////////////////////////////////////////////////////////////////////////// -/// Matrix multiply-add operation: S32 = S8 * S8 + S32 +/// Matrix multiply-add operation: S32 = S4 * S4 + S32 template <> struct SparseMma< - gemm::GemmShape<16,8,64>, + gemm::GemmShape<16,8,128>, 32, - int8_t, + cutlass::int4b_t, layout::RowMajor, - int8_t, + cutlass::int4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate, SPFormatType::Thread> { - using Shape = gemm::GemmShape<16,8,64>; + using Shape = gemm::GemmShape<16,8,128>; - using ElementA = int8_t; + using ElementA = cutlass::int4b_t; using LayoutA = layout::RowMajor; - using FragmentA = Array; + using FragmentA = Array; - using ElementB = int8_t; + using ElementB = cutlass::int4b_t; using LayoutB = layout::ColumnMajor; - using FragmentB = Array; + using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; @@ -715,7 +884,7 @@ struct SparseMma< using 
FragmentE = uint32_t; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; static int const kSparse = 2; @@ -743,46 +912,64 @@ struct SparseMma< int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k128.row.col.s32.s4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#else + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#endif #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } }; -/// Matrix multiply-add operation: S32 = S8 * U8 + S32 +/// Matrix multiply-add operation: S32 = S4 * U4 + S32 template <> struct SparseMma< - gemm::GemmShape<16,8,64>, + gemm::GemmShape<16,8,128>, 32, - int8_t, + cutlass::int4b_t, layout::RowMajor, - uint8_t, + cutlass::uint4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate, SPFormatType::Thread> { - using Shape = gemm::GemmShape<16,8,64>; + using Shape = gemm::GemmShape<16,8,128>; - using ElementA = int8_t; + using ElementA = cutlass::int4b_t; using LayoutA = layout::RowMajor; - using FragmentA = Array; + using FragmentA = Array; - using ElementB = uint8_t; + using ElementB = cutlass::uint4b_t; using LayoutB = layout::ColumnMajor; - using FragmentB = Array; + using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; @@ -790,7 +977,7 @@ struct SparseMma< using FragmentE = uint32_t; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; static int const kSparse = 2; @@ -818,46 +1005,64 @@ struct SparseMma< int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k64.row.col.s32.s8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k128.row.col.s32.s4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : 
"=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#else + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#endif #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } }; -/// Matrix multiply-add operation: S32 = U8 * S8 + S32 +/// Matrix multiply-add operation: S32 = U4 * S4 + S32 template <> struct SparseMma< - gemm::GemmShape<16,8,64>, + gemm::GemmShape<16,8,128>, 32, - uint8_t, + cutlass::uint4b_t, layout::RowMajor, - int8_t, + cutlass::int4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate, SPFormatType::Thread> { - using Shape = gemm::GemmShape<16,8,64>; + using Shape = gemm::GemmShape<16,8,128>; - using ElementA = uint8_t; + using ElementA = cutlass::uint4b_t; using LayoutA = layout::RowMajor; - using FragmentA = Array; + using FragmentA = Array; - using ElementB = int8_t; + using ElementB = cutlass::int4b_t; using LayoutB = layout::ColumnMajor; - using FragmentB = Array; + using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; @@ -865,7 +1070,7 @@ struct SparseMma< using FragmentE = uint32_t; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; static int const kSparse = 2; @@ -893,46 +1098,64 @@ struct SparseMma< int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.s8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k128.row.col.s32.u4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#else + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } +#endif #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } }; -/// Matrix multiply-add operation: S32 = U8 * U8 + S32 +/// Matrix multiply-add operation: S32 = U4 * U4 + S32 template <> struct SparseMma< - gemm::GemmShape<16,8,64>, + gemm::GemmShape<16,8,128>, 32, 
- uint8_t, + cutlass::uint4b_t, layout::RowMajor, - uint8_t, + cutlass::uint4b_t, layout::ColumnMajor, int, layout::RowMajor, OpMultiplyAddSaturate, SPFormatType::Thread> { - using Shape = gemm::GemmShape<16,8,64>; + using Shape = gemm::GemmShape<16,8,128>; - using ElementA = uint8_t; + using ElementA = cutlass::uint4b_t; using LayoutA = layout::RowMajor; - using FragmentA = Array; + using FragmentA = Array; - using ElementB = uint8_t; + using ElementB = cutlass::uint4b_t; using LayoutB = layout::ColumnMajor; - using FragmentB = Array; + using FragmentB = Array; using ElementC = int; using LayoutC = layout::RowMajor; @@ -940,7 +1163,7 @@ struct SparseMma< using FragmentE = uint32_t; - using Operator = OpMultiplyAdd; + using Operator = OpMultiplyAddSaturate; using ArchTag = arch::Sm80; static int const kSparse = 2; @@ -968,630 +1191,36 @@ struct SparseMma< int const *C = reinterpret_cast(&c); int *D = reinterpret_cast(&d); - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k64.row.col.s32.u8.u8.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); - +#if ((__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 5)) + if (id2 == 0) { + asm volatile( + "mma.sp::ordered_metadata.sync.aligned.m16n8k128.row.col.s32.u4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } #else - - assert(0); + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); + } else { + assert(0); + } #endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// Sparse Matrix Multiply 168128 - S4 input, S32 accumulation -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: S32 = S4 * S4 + S32 -template <> -struct SparseMma< - gemm::GemmShape<16,8,128>, - 32, - cutlass::int4b_t, - layout::RowMajor, - cutlass::int4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd, - SPFormatType::Thread> { - - using Shape = gemm::GemmShape<16,8,128>; - - using ElementA = cutlass::int4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::int4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 1; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c, - uint32_t const &E, - int const 
id2 - ) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.s4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); - -#else - - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = S4 * U4 + S32 -template <> -struct SparseMma< - gemm::GemmShape<16,8,128>, - 32, - cutlass::int4b_t, - layout::RowMajor, - cutlass::uint4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd, - SPFormatType::Thread> { - - using Shape = gemm::GemmShape<16,8,128>; - - using ElementA = cutlass::int4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::uint4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 1; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c, - uint32_t const &E, - int const id2 - ) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.u4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); - -#else - - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U4 * S4 + S32 -template <> -struct SparseMma< - gemm::GemmShape<16,8,128>, - 32, - cutlass::uint4b_t, - layout::RowMajor, - cutlass::int4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd, - SPFormatType::Thread> { - - using Shape = gemm::GemmShape<16,8,128>; - - using ElementA = cutlass::uint4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::int4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 1; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c, - uint32_t const &E, - int const id2 - ) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = 
reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.s4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); - -#else - - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U4 * U4 + S32 -template <> -struct SparseMma< - gemm::GemmShape<16,8,128>, - 32, - cutlass::uint4b_t, - layout::RowMajor, - cutlass::uint4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAdd, - SPFormatType::Thread> { - - using Shape = gemm::GemmShape<16,8,128>; - - using ElementA = cutlass::uint4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::uint4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 1; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c, - uint32_t const &E, - int const id2 - ) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.u4.s32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); - -#else - - assert(0); -#endif - } -}; - -//////////////////////////////////////////////////////////////////////////////// -// -// Sparse Matrix Multiply 168128 - S4 input, S32 accumulation - SATURATE -// -//////////////////////////////////////////////////////////////////////////////// - -/// Matrix multiply-add operation: S32 = S4 * S4 + S32 -template <> -struct SparseMma< - gemm::GemmShape<16,8,128>, - 32, - cutlass::int4b_t, - layout::RowMajor, - cutlass::int4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate, - SPFormatType::Thread> { - - using Shape = gemm::GemmShape<16,8,128>; - - using ElementA = cutlass::int4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::int4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 1; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c, - uint32_t const &E, - int const id2 - ) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - 
uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); - -#else - - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = S4 * U4 + S32 -template <> -struct SparseMma< - gemm::GemmShape<16,8,128>, - 32, - cutlass::int4b_t, - layout::RowMajor, - cutlass::uint4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate, - SPFormatType::Thread> { - - using Shape = gemm::GemmShape<16,8,128>; - - using ElementA = cutlass::int4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::uint4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 1; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c, - uint32_t const &E, - int const id2 - ) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k128.row.col.s32.s4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); - -#else - - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U4 * S4 + S32 -template <> -struct SparseMma< - gemm::GemmShape<16,8,128>, - 32, - cutlass::uint4b_t, - layout::RowMajor, - cutlass::int4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate, - SPFormatType::Thread> { - - using Shape = gemm::GemmShape<16,8,128>; - - using ElementA = cutlass::uint4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::int4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 1; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c, - uint32_t const &E, - int const id2 - ) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - if (id2 == 0) - asm volatile( - 
"mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.s4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); - -#else - - assert(0); -#endif - } -}; - -/// Matrix multiply-add operation: S32 = U4 * U4 + S32 -template <> -struct SparseMma< - gemm::GemmShape<16,8,128>, - 32, - cutlass::uint4b_t, - layout::RowMajor, - cutlass::uint4b_t, - layout::ColumnMajor, - int, - layout::RowMajor, - OpMultiplyAddSaturate, - SPFormatType::Thread> { - - using Shape = gemm::GemmShape<16,8,128>; - - using ElementA = cutlass::uint4b_t; - using LayoutA = layout::RowMajor; - using FragmentA = Array; - - using ElementB = cutlass::uint4b_t; - using LayoutB = layout::ColumnMajor; - using FragmentB = Array; - - using ElementC = int; - using LayoutC = layout::RowMajor; - using FragmentC = Array; - - using FragmentE = uint32_t; - - using Operator = OpMultiplyAdd; - using ArchTag = arch::Sm80; - - static int const kSparse = 2; - - static int const kMetaSizeInBits = 2; - - static int const kMaxID2 = 1; - - /// Computes multiply-add - CUTLASS_HOST_DEVICE - void operator()( - FragmentC &d, - FragmentA const &a, - FragmentB const &b, - FragmentC const &c, - uint32_t const &E, - int const id2 - ) const { - -#if defined(CUTLASS_ARCH_SPARSE_MMA_SM80_ENABLED) - - uint32_t const *A = reinterpret_cast(&a); - uint32_t const *B = reinterpret_cast(&b); - - int const *C = reinterpret_cast(&c); - int *D = reinterpret_cast(&d); - - if (id2 == 0) - asm volatile( - "mma.sp.sync.aligned.m16n8k128.row.col.s32.u4.u4.s32.satfinite {%0,%1,%2,%3}, {%4,%5,%6,%7}, " - "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" - : "=r"(D[0]), "=r"(D[1]), "=r"(D[2]), "=r"(D[3]) - : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), - "r"(C[0]), "r"(C[1]), "r"(C[2]), "r"(C[3]), "r"(E)); - else - assert(0); #else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); assert(0); #endif } diff --git a/include/cutlass/arch/mma_sparse_sm89.h b/include/cutlass/arch/mma_sparse_sm89.h new file mode 100644 index 0000000000..2fae35be42 --- /dev/null +++ b/include/cutlass/arch/mma_sparse_sm89.h @@ -0,0 +1,405 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Sparse matrix multiply accumulate for SM89 +*/ + +#pragma once + +#include + +#include "mma.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/numeric_types.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#if (__CUDACC_VER_MAJOR__ > 12) || (__CUDACC_VER_MAJOR__ == 12 && __CUDACC_VER_MINOR__ >= 4) + +# define CUTLASS_ARCH_SPARSE_MMA_SM89_SUPPORTED 1 +#endif + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM89_SUPPORTED) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ == 890) +# define CUTLASS_ARCH_SPARSE_MMA_SM89_ENABLED +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace arch { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = fe4m3 * fe4m3 + F32 +template +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + cutlass::float_e4m3_t, + layout::RowMajor, + cutlass::float_e4m3_t, + layout::ColumnMajor, + float, + layout::RowMajor, + Operator_, + SPFormatType::Thread> { + + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = cutlass::float_e4m3_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e4m3_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.f32.e4m3.e4m3.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } + else { + assert(0); + } 
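For reference, a direct invocation of the FP8 specialization above could look like the sketch below. This is an illustrative addition, not part of the upstream patch: the wrapper name `sparse_mma_example`, its parameters, and the choice of `OpMultiplyAdd` (assumed here to be one of the operators accepted by the `static_assert` above) are assumptions; real kernels obtain the fragments and the sparse metadata word from CUTLASS's sparse MMA iterators rather than passing them in directly.

// Illustrative sketch only (editorial addition); assumes OpMultiplyAdd satisfies the static_assert above.
using SparseMmaF8 = cutlass::arch::SparseMma<
    cutlass::gemm::GemmShape<16, 8, 64>, 32,
    cutlass::float_e4m3_t, cutlass::layout::RowMajor,
    cutlass::float_e4m3_t, cutlass::layout::ColumnMajor,
    float, cutlass::layout::RowMajor,
    cutlass::arch::OpMultiplyAdd, cutlass::arch::SPFormatType::Thread>;

// Hypothetical wrapper showing the operator() calling convention of the functor defined above.
__device__ void sparse_mma_example(SparseMmaF8::FragmentC &d,
                                   SparseMmaF8::FragmentA const &a,
                                   SparseMmaF8::FragmentB const &b,
                                   SparseMmaF8::FragmentC const &c,
                                   uint32_t metadata) {  // packed 2:4 sparsity selectors for this thread
  SparseMmaF8 mma;
  mma(d, a, b, c, metadata, /*id2=*/0);  // kMaxID2 == 1, so only id2 == 0 is valid here
}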
+#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = fe4m3 * fe5m2 + F32 +template +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + cutlass::float_e4m3_t, + layout::RowMajor, + cutlass::float_e5m2_t, + layout::ColumnMajor, + float, + layout::RowMajor, + Operator_, + SPFormatType::Thread> { + + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = cutlass::float_e4m3_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e5m2_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.f32.e4m3.e5m2.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } + else { + assert(0); + } +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = fe5m2 * fe4m3 + F32 +template +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + cutlass::float_e5m2_t, + layout::RowMajor, + cutlass::float_e4m3_t, + layout::ColumnMajor, + float, + layout::RowMajor, + Operator_, + SPFormatType::Thread> { + + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = cutlass::float_e5m2_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e4m3_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + 
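// Reader's note (editorial, not part of the upstream patch): the reinterpret_casts above expose
// the FP8 fragments as packed 32-bit values, which is what the "r" operands of the mma.sp
// instruction below consume. Assuming the fragment sizes implied by the four operands bound per
// matrix, the packing corresponds to:
//
//   static_assert(sizeof(FragmentA) == 4 * sizeof(uint32_t), "A fragment = four .b32 registers");
//   static_assert(sizeof(FragmentB) == 4 * sizeof(uint32_t), "B fragment = four .b32 registers");
//   static_assert(sizeof(FragmentC) == 4 * sizeof(float),    "C/D fragments = four .f32 registers");
//
// The F32 accumulator fragments C and D are bound directly as "f" registers.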
float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.f32.e5m2.e4m3.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } + else { + assert(0); + } +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Matrix multiply-add operation: F32 = fe5m2 * fe5m2 + F32 +template +struct SparseMma< + gemm::GemmShape<16,8,64>, + 32, + cutlass::float_e5m2_t, + layout::RowMajor, + cutlass::float_e5m2_t, + layout::ColumnMajor, + float, + layout::RowMajor, + Operator_, + SPFormatType::Thread> { + + static_assert(platform::is_same::value || + platform::is_same::value, + "Invalid operator for SM89 FP8 instruction"); + + using Shape = gemm::GemmShape<16,8,64>; + + using ElementA = cutlass::float_e5m2_t; + using LayoutA = layout::RowMajor; + using FragmentA = Array; + + using ElementB = cutlass::float_e5m2_t; + using LayoutB = layout::ColumnMajor; + using FragmentB = Array; + + using ElementC = float; + using LayoutC = layout::RowMajor; + using FragmentC = Array; + + using FragmentE = uint32_t; + + using Operator = Operator_; + using ArchTag = arch::Sm89; + + static int const kSparse = 2; + + static int const kMetaSizeInBits = 2; + + static int const kMaxID2 = 1; + + /// Computes multiply-add + CUTLASS_HOST_DEVICE + void operator()( + FragmentC &d, + FragmentA const &a, + FragmentB const &b, + FragmentC const &c, + uint32_t const &E, + int const id2 + ) const { + +#if defined(CUTLASS_ARCH_SPARSE_MMA_SM89_ENABLED) + + uint32_t const *A = reinterpret_cast(&a); + uint32_t const *B = reinterpret_cast(&b); + + float const *C = reinterpret_cast(&c); + float *D = reinterpret_cast(&d); + + if (id2 == 0) { + asm volatile( + "mma.sp.sync.aligned.m16n8k64.row.col.f32.e5m2.e5m2.f32 {%0,%1,%2,%3}, {%4,%5,%6,%7}, " + "{%8,%9,%10,%11}, {%12,%13,%14,%15}, %16, 0x0;\n" + : "=f"(D[0]), "=f"(D[1]), "=f"(D[2]), "=f"(D[3]) + : "r"(A[0]), "r"(A[1]), "r"(A[2]), "r"(A[3]), "r"(B[0]), "r"(B[1]), "r"(B[2]), "r"(B[3]), + "f"(C[0]), "f"(C[1]), "f"(C[2]), "f"(C[3]), "r"(E)); + } + else { + assert(0); + } +#else + CUTLASS_UNUSED(a); + CUTLASS_UNUSED(b); + CUTLASS_UNUSED(c); + CUTLASS_UNUSED(d); + assert(0); +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/arch/reg_reconfig.h b/include/cutlass/arch/reg_reconfig.h new file mode 100644 index 0000000000..d2b434453e --- /dev/null +++ b/include/cutlass/arch/reg_reconfig.h @@ -0,0 +1,67 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this
+ * list of conditions and the following disclaimer.
+ *
+ * 2. Redistributions in binary form must reproduce the above copyright notice,
+ * this list of conditions and the following disclaimer in the documentation
+ * and/or other materials provided with the distribution.
+ *
+ * 3. Neither the name of the copyright holder nor the names of its
+ * contributors may be used to endorse or promote products derived from
+ * this software without specific prior written permission.
+ *
+ * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS"
+ * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE
+ * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
+ * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
+ * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
+ * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
+ * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
+ * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
+ * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
+ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
+ *
+ **************************************************************************************************/
+
+/*! \file
+    \brief PTX for CTA Reconfiguration
+*/
+
+#pragma once
+
+#include "cutlass/cutlass.h"
+
+#ifndef CUDA_CTA_RECONFIG_ACTIVATED
+  #if (__CUDACC_VER_MAJOR__ >= 12 && \
+    defined(__CUDA_ARCH__) && __CUDA_ARCH__ >= 900 && defined(__CUDA_ARCH_FEAT_SM90_ALL))
+    #define CUDA_CTA_RECONFIG_ACTIVATED 1
+  #endif
+#endif
+
+namespace cutlass {
+namespace arch {
+
+template <uint32_t RegCount>
+CUTLASS_DEVICE
+void warpgroup_reg_alloc(){
+#if CUDA_CTA_RECONFIG_ACTIVATED
+  asm volatile( "setmaxnreg.inc.sync.aligned.u32 %0;\n" : : "n"(RegCount) );
+#endif
+}
+
+template <uint32_t RegCount>
+CUTLASS_DEVICE
+void warpgroup_reg_dealloc(){
+#if CUDA_CTA_RECONFIG_ACTIVATED
+  asm volatile( "setmaxnreg.dec.sync.aligned.u32 %0;\n" : : "n"(RegCount) );
+#endif
+}
+
+} // namespace arch
+} // namespace cutlass
diff --git a/include/cutlass/arch/simd.h b/include/cutlass/arch/simd.h
index abbabf94b3..f670fc293f 100644
--- a/include/cutlass/arch/simd.h
+++ b/include/cutlass/arch/simd.h
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
  * SPDX-License-Identifier: BSD-3-Clause
  *
  * Redistribution and use in source and binary forms, with or without
@@ -34,8 +34,8 @@
 
 #pragma once
 
-#include "../array.h"
-#include "../numeric_types.h"
+#include "cutlass/arch/array.h"
+#include "cutlass/arch/numeric_types.h"
 
 namespace cutlass {
 namespace arch {
diff --git a/include/cutlass/arch/simd_sm60.h b/include/cutlass/arch/simd_sm60.h
index 1de8d42b05..6e1ef20441 100644
--- a/include/cutlass/arch/simd_sm60.h
+++ b/include/cutlass/arch/simd_sm60.h
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -50,8 +50,6 @@ template <> Array operator*(Array const &a, Array const &b) { Array d; - // TODO - return d; } @@ -60,8 +58,6 @@ template <> Array operator+(AArray const &a, Array const &b) { Array d; - // TODO - return d; } @@ -70,8 +66,6 @@ template <> Array operator-(Array const &a, Array const &b) { Array d; - // TODO - return d; } @@ -83,8 +77,6 @@ template <> Array mac(Array const &a, Array const &b, Array const &c) { Array d; - // TODO - return d; } @@ -95,8 +87,6 @@ CUTLASS_HOST_DEVICE template <> half_t dot(Array const &a, Array const &b, half_t accum) { - // TODO - return accum; } @@ -105,8 +95,6 @@ CUTLASS_HOST_DEVICE template <> float dot(Array const &a, Array const &b, float accum) { - // TODO - return accum; } diff --git a/include/cutlass/arch/simd_sm61.h b/include/cutlass/arch/simd_sm61.h index da236c16ea..b783c943ec 100644 --- a/include/cutlass/arch/simd_sm61.h +++ b/include/cutlass/arch/simd_sm61.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/include/cutlass/arch/synclog.hpp b/include/cutlass/arch/synclog.hpp new file mode 100644 index 0000000000..8cf65ad73e --- /dev/null +++ b/include/cutlass/arch/synclog.hpp @@ -0,0 +1,1324 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! 
\file
+    \brief Synchronization event logging for race condition debugging.
+*/
+
+#pragma once
+
+#include "cutlass/detail/helper_macros.hpp"
+
+#if defined(__CUDACC_RTC__)
+#include <cuda/std/cstdint>
+#else
+#include <cstdint>
+#endif
+
+#if !defined(__CUDACC_RTC__)
+#include <mutex>
+#include <vector>
+#endif
+
+namespace cutlass {
+namespace arch {
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#if defined(CUTLASS_ENABLE_SYNCLOG)
+
+constexpr uint32_t synclog_cap = 1 << 26;
+
+inline std::mutex synclog_mutex;
+inline std::vector<uint32_t*> synclog_buf_list;
+#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
+CUTLASS_DEVICE uint32_t* synclog_buf;
+#endif
+
+CUTLASS_DEVICE
+uint32_t* synclog_alloc(uint32_t n) {
+  #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
+  uint32_t* buf = synclog_buf;
+  if (buf == nullptr) return nullptr;
+  uint32_t last = atomicAdd(&buf[0], n);
+  if (last + n < synclog_cap) return buf + last + 1;
+  if (last >= synclog_cap) atomicAdd(&buf[0], -n);
+  #endif
+  return nullptr;
+}
+
+CUTLASS_DEVICE
+void synclog_emit_prefix(uint32_t* to, uint32_t header, uint32_t line) {
+  #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__))
+  uint64_t time64;
+  asm volatile (
+    "mov.u64 %0, %%globaltimer;\n"
+    : "=l"(time64) :
+  );
+  to[0] = header;
+  to[1] = line;
+  to[2] = time64;
+  to[3] = time64 >> 32;
+  to[4] = threadIdx.x;
+  to[5] = threadIdx.y;
+  to[6] = threadIdx.z;
+  to[7] = blockIdx.x;
+  to[8] = blockIdx.y;
+  to[9] = blockIdx.z;
+  #endif
+}
+
+constexpr uint32_t synclog_header_none = 0;
+constexpr uint32_t synclog_length_prefix = 1 + 1 + 2 + 3 + 3;
+
+constexpr bool synclog_enable_syncthreads = true;
+constexpr uint32_t synclog_header_syncthreads = 1;
+constexpr uint32_t synclog_length_syncthreads = synclog_length_prefix + 0;
+
+constexpr bool synclog_enable_syncwarp = true;
+constexpr uint32_t synclog_header_syncwarp = 2;
+constexpr uint32_t synclog_length_syncwarp = synclog_length_prefix + 0;
+
+constexpr bool synclog_enable_named_barrier_arrive_and_wait = true;
+constexpr uint32_t synclog_header_named_barrier_arrive_and_wait = 3;
+constexpr uint32_t synclog_length_named_barrier_arrive_and_wait = synclog_length_prefix + 2;
+
+constexpr bool synclog_enable_named_barrier_arrive = true;
+constexpr uint32_t synclog_header_named_barrier_arrive = 4;
+constexpr uint32_t synclog_length_named_barrier_arrive = synclog_length_prefix + 2;
+
+constexpr bool synclog_enable_cluster_barrier_init = true;
+constexpr uint32_t synclog_header_cluster_barrier_init = 5;
+constexpr uint32_t synclog_length_cluster_barrier_init = synclog_length_prefix + 2;
+
+constexpr bool synclog_enable_cluster_barrier_wait = true;
+constexpr uint32_t synclog_header_cluster_barrier_wait = 6;
+constexpr uint32_t synclog_length_cluster_barrier_wait = synclog_length_prefix + 4;
+
+constexpr bool synclog_enable_cluster_barrier_test_wait = true;
+constexpr uint32_t synclog_header_cluster_barrier_test_wait = 7;
+constexpr uint32_t synclog_length_cluster_barrier_test_wait = synclog_length_prefix + 5;
+
+constexpr bool synclog_enable_cluster_barrier_try_wait = true;
+constexpr uint32_t synclog_header_cluster_barrier_try_wait = 8;
+constexpr uint32_t synclog_length_cluster_barrier_try_wait = synclog_length_prefix + 4;
+
+constexpr bool synclog_enable_cluster_barrier_arrive_cluster = true;
+constexpr uint32_t synclog_header_cluster_barrier_arrive_cluster = 9;
+constexpr uint32_t synclog_length_cluster_barrier_arrive_cluster = synclog_length_prefix + 5;
+
+constexpr bool
synclog_enable_cluster_barrier_arrive = true; +constexpr uint32_t synclog_header_cluster_barrier_arrive = 10; +constexpr uint32_t synclog_length_cluster_barrier_arrive = synclog_length_prefix + 3; + +constexpr bool synclog_enable_cluster_barrier_invalidate = true; +constexpr uint32_t synclog_header_cluster_barrier_invalidate = 11; +constexpr uint32_t synclog_length_cluster_barrier_invalidate = synclog_length_prefix + 3; + +constexpr bool synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx = true; +constexpr uint32_t synclog_header_cluster_transaction_barrier_arrive_and_expect_tx = 12; +constexpr uint32_t synclog_length_cluster_transaction_barrier_arrive_and_expect_tx = synclog_length_prefix + 4; + +constexpr bool synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx_cluster = true; +constexpr uint32_t synclog_header_cluster_transaction_barrier_arrive_and_expect_tx_cluster = 13; +constexpr uint32_t synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster = synclog_length_prefix + 6; + +constexpr bool synclog_enable_cluster_transaction_barrier_expect_transaction = true; +constexpr uint32_t synclog_header_cluster_transaction_barrier_expect_transaction = 14; +constexpr uint32_t synclog_length_cluster_transaction_barrier_expect_transaction = synclog_length_prefix + 4; + +constexpr bool synclog_enable_cluster_transaction_barrier_complete_transaction = true; +constexpr uint32_t synclog_header_cluster_transaction_barrier_complete_transaction = 15; +constexpr uint32_t synclog_length_cluster_transaction_barrier_complete_transaction = synclog_length_prefix + 6; + +constexpr bool synclog_enable_fence_barrier_init = true; +constexpr uint32_t synclog_header_fence_barrier_init = 16; +constexpr uint32_t synclog_length_fence_barrier_init = synclog_length_prefix + 0; + +constexpr bool synclog_enable_fence_view_async_shared = true; +constexpr uint32_t synclog_header_fence_view_async_shared = 17; +constexpr uint32_t synclog_length_fence_view_async_shared = synclog_length_prefix + 0; + +constexpr bool synclog_enable_cp_async_wait = true; +constexpr uint32_t synclog_header_cp_async_wait = 18; +constexpr uint32_t synclog_length_cp_async_wait = synclog_length_prefix + 1; + +constexpr bool synclog_enable_cp_async_wait_all = true; +constexpr uint32_t synclog_header_cp_async_wait_all = 19; +constexpr uint32_t synclog_length_cp_async_wait_all = synclog_length_prefix + 0; + +constexpr bool synclog_enable_cp_async_fence = true; +constexpr uint32_t synclog_header_cp_async_fence = 20; +constexpr uint32_t synclog_length_cp_async_fence = synclog_length_prefix + 0; + +constexpr bool synclog_enable_cp_async_nan = true; +constexpr uint32_t synclog_header_cp_async_nan = 21; +constexpr uint32_t synclog_length_cp_async_nan = synclog_length_prefix + 4; + +constexpr bool synclog_enable_cp_async_zfill = true; +constexpr uint32_t synclog_header_cp_async_zfill = 22; +constexpr uint32_t synclog_length_cp_async_zfill = synclog_length_prefix + 5; + +constexpr bool synclog_enable_cp_async = true; +constexpr uint32_t synclog_header_cp_async = 23; +constexpr uint32_t synclog_length_cp_async = synclog_length_prefix + 5; + +constexpr bool synclog_enable_tma_load = true; +constexpr uint32_t synclog_header_tma_load = 24; +constexpr uint32_t synclog_length_tma_load = synclog_length_prefix + 4; + +constexpr bool synclog_enable_tma_store = true; +constexpr uint32_t synclog_header_tma_store = 25; +constexpr uint32_t synclog_length_tma_store = synclog_length_prefix + 3; + +constexpr bool 
synclog_enable_tma_store_arrive = true; +constexpr uint32_t synclog_header_tma_store_arrive = 26; +constexpr uint32_t synclog_length_tma_store_arrive = synclog_length_prefix + 0; + +constexpr bool synclog_enable_tma_store_wait = true; +constexpr uint32_t synclog_header_tma_store_wait = 27; +constexpr uint32_t synclog_length_tma_store_wait = synclog_length_prefix + 1; + +constexpr bool synclog_enable_warpgroup_arrive = true; +constexpr uint32_t synclog_header_warpgroup_arrive = 28; +constexpr uint32_t synclog_length_warpgroup_arrive = synclog_length_prefix + 0; + +constexpr bool synclog_enable_warpgroup_wait = true; +constexpr uint32_t synclog_header_warpgroup_wait = 29; +constexpr uint32_t synclog_length_warpgroup_wait = synclog_length_prefix + 1; + +constexpr bool synclog_enable_warpgroup_commit_batch = true; +constexpr uint32_t synclog_header_warpgroup_commit_batch = 30; +constexpr uint32_t synclog_length_warpgroup_commit_batch = synclog_length_prefix + 0; + +constexpr bool synclog_enable_wgmma_reg_smem = true; +constexpr uint32_t synclog_header_wgmma_reg_smem = 31; +constexpr uint32_t synclog_length_wgmma_reg_smem = synclog_length_prefix + 2; + +constexpr bool synclog_enable_wgmma_smem_smem = true; +constexpr uint32_t synclog_header_wgmma_smem_smem = 32; +constexpr uint32_t synclog_length_wgmma_smem_smem = synclog_length_prefix + 4; + +constexpr bool synclog_enable_cpasync_barrier_arrive = true; +constexpr uint32_t synclog_header_cpasync_barrier_arrive = 33; +constexpr uint32_t synclog_length_cpasync_barrier_arrive = synclog_length_prefix + 3; + +CUTLASS_DEVICE +bool synclog_condition_emit() { + #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + return threadIdx.x%NumThreadsPerWarp == 0 && threadIdx.y == 0 && threadIdx.z == 0 && + blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0; + #else + return 0; + #endif +} + +CUTLASS_DEVICE +bool synclog_condition_print() { + #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + return threadIdx.x == 0 && threadIdx.y == 0 && threadIdx.z == 0 && + blockIdx.x == 0 && blockIdx.y == 0 && blockIdx.z == 0; + #else + return false; + #endif +} + +CUTLASS_DEVICE +void synclog_print_prefix(char const* header, uint32_t at) { + #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + uint32_t line = synclog_buf[at + 1]; + uint32_t timeLo = synclog_buf[at + 2]; + uint32_t timeHi = synclog_buf[at + 3]; + uint32_t threadIdxX = synclog_buf[at + 4]; + uint32_t threadIdxY = synclog_buf[at + 5]; + uint32_t threadIdxZ = synclog_buf[at + 6]; + uint32_t blockIdxX = synclog_buf[at + 7]; + uint32_t blockIdxY = synclog_buf[at + 8]; + uint32_t blockIdxZ = synclog_buf[at + 9]; + printf( + "%s line=%u time=%lu thread=%u,%u,%u block=%u,%u,%u ", + header, line, + (uint64_t)timeHi << 32 | timeLo, + threadIdxX, threadIdxY, threadIdxZ, + blockIdxX, blockIdxY, blockIdxZ + ); + #endif +} + +CUTLASS_DEVICE +uint64_t synclog_mbarrier_bits(uint32_t smem_addr) { + uint64_t bits = 0; + asm volatile ( + "mbarrier.inval.shared::cta.b64 [%1];\n" + "ld.shared::cta.b64 %0, [%1];\n" + : "=l"(bits) : "r"(smem_addr) + ); + return bits; +} + +CUTLASS_DEVICE +void synclog_print_wgmma_desc(char const* str, uint32_t lo, uint32_t hi, char const* sep) { + CUTLASS_UNUSED(hi); + uint32_t smem_int_ptr = (lo & ((1 << 14) - 1)) << 4; + printf("%s_smem_int_ptr=%u%s", str, smem_int_ptr, sep); +} + +#endif // defined(CUTLASS_ENABLE_SYNCLOG) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +inline void 
synclog_setup() { + #if defined(CUTLASS_ENABLE_SYNCLOG) + #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + std::scoped_lock lock(synclog_mutex); + auto fail = [] () { + fprintf(stderr, "synclog_setup() failed\n"); + std::terminate(); + }; + int orig_device = 0; + if (cudaGetDevice(&orig_device) != cudaSuccess) { + fail(); + } + int device_count = 0; + if (cudaGetDeviceCount(&device_count) != cudaSuccess) { + fail(); + } + if (synclog_buf_list.size() == 0) { + for (int device = 0; device < device_count; device++) { + uint32_t* buf = 0; + if (cudaSetDevice(device) != cudaSuccess || + cudaMalloc(&buf, synclog_cap * sizeof(uint32_t)) != cudaSuccess) { + fail(); + } + synclog_buf_list.push_back(buf); + } + } + for (int device = 0; device < device_count; device++) { + uint32_t* buf = synclog_buf_list.at(device); + if (cudaSetDevice(device) != cudaSuccess || + cudaMemset(buf, 0, synclog_cap * sizeof(uint32_t)) != cudaSuccess || + cudaMemcpyToSymbol(synclog_buf, &buf, sizeof(buf)) != cudaSuccess) { + fail(); + } + } + if (cudaSetDevice(orig_device) != cudaSuccess) { + fail(); + } + #endif + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_syncthreads(uint32_t line) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_syncthreads) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_syncthreads); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_syncthreads, line); + #else + CUTLASS_UNUSED(line); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_syncwarp(uint32_t line) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_syncwarp) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_syncwarp); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_syncwarp, line); + #else + CUTLASS_UNUSED(line); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_named_barrier_arrive_and_wait( + uint32_t line, + uint32_t num_threads, + uint32_t barrier_id) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_named_barrier_arrive_and_wait) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_named_barrier_arrive_and_wait); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_named_barrier_arrive_and_wait, line); + to[synclog_length_prefix + 0] = num_threads; + to[synclog_length_prefix + 1] = barrier_id; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(num_threads); + CUTLASS_UNUSED(barrier_id); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_named_barrier_arrive( + uint32_t line, + uint32_t num_threads, + uint32_t barrier_id) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_named_barrier_arrive) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_named_barrier_arrive); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_named_barrier_arrive, line); + to[synclog_length_prefix + 0] = num_threads; + to[synclog_length_prefix + 1] = barrier_id; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(num_threads); + CUTLASS_UNUSED(barrier_id); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_barrier_init( + uint32_t line, + uint32_t smem_addr, + uint32_t arrive_count) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr 
(!synclog_enable_cluster_barrier_init) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_init); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_barrier_init, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = arrive_count; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(arrive_count); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_barrier_wait( + uint32_t line, + uint32_t smem_addr, + uint32_t phase) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_barrier_wait) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_wait); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_barrier_wait, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = phase; + to[synclog_length_prefix + 2] = bits; + to[synclog_length_prefix + 3] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(phase); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_barrier_test_wait( + uint32_t line, + uint32_t smem_addr, + uint32_t phase, + uint32_t pred) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_barrier_test_wait) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_test_wait); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_barrier_test_wait, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = phase; + to[synclog_length_prefix + 2] = pred; + to[synclog_length_prefix + 3] = bits; + to[synclog_length_prefix + 4] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(phase); + CUTLASS_UNUSED(pred); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_barrier_try_wait( + uint32_t line, + uint32_t smem_addr, + uint32_t phase) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_barrier_try_wait) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_try_wait); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_barrier_try_wait, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = phase; + to[synclog_length_prefix + 2] = bits; + to[synclog_length_prefix + 3] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(phase); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_barrier_arrive_cluster( + uint32_t line, + uint32_t smem_addr, + uint32_t cta_id, + uint32_t pred) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_barrier_arrive_cluster) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_arrive_cluster); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_barrier_arrive_cluster, line); + to[synclog_length_prefix + 0] = smem_addr; + 
to[synclog_length_prefix + 1] = cta_id; + to[synclog_length_prefix + 2] = pred; + to[synclog_length_prefix + 3] = bits; + to[synclog_length_prefix + 4] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(cta_id); + CUTLASS_UNUSED(pred); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_barrier_arrive( + uint32_t line, + uint32_t smem_addr) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_barrier_arrive) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_arrive); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_barrier_arrive, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = bits; + to[synclog_length_prefix + 2] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_barrier_invalidate( + uint32_t line, + uint32_t smem_addr) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_barrier_invalidate) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_barrier_invalidate); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_barrier_invalidate, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = bits; + to[synclog_length_prefix + 2] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_transaction_barrier_arrive_and_expect_tx( + uint32_t line, + uint32_t smem_addr, + uint32_t transaction_bytes) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_arrive_and_expect_tx); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_arrive_and_expect_tx, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = transaction_bytes; + to[synclog_length_prefix + 2] = bits; + to[synclog_length_prefix + 3] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(transaction_bytes); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_transaction_barrier_arrive_and_expect_tx_cluster( + uint32_t line, + uint32_t smem_addr, + uint32_t transaction_bytes, + uint32_t cta_id, + uint32_t pred) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx_cluster) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_arrive_and_expect_tx_cluster, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = transaction_bytes; + to[synclog_length_prefix + 2] = cta_id; + to[synclog_length_prefix + 3] = pred; + 
to[synclog_length_prefix + 4] = bits; + to[synclog_length_prefix + 5] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(transaction_bytes); + CUTLASS_UNUSED(cta_id); + CUTLASS_UNUSED(pred); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_transaction_barrier_expect_transaction( + uint32_t line, + uint32_t smem_addr, + uint32_t transaction_bytes) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_transaction_barrier_expect_transaction) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_expect_transaction); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_expect_transaction, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = transaction_bytes; + to[synclog_length_prefix + 2] = bits; + to[synclog_length_prefix + 3] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(transaction_bytes); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cluster_transaction_barrier_complete_transaction( + uint32_t line, + uint32_t smem_addr, + uint32_t dst_cta_id, + uint32_t transaction_bytes, + uint32_t pred) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cluster_transaction_barrier_complete_transaction) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cluster_transaction_barrier_complete_transaction); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cluster_transaction_barrier_complete_transaction, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = dst_cta_id; + to[synclog_length_prefix + 2] = transaction_bytes; + to[synclog_length_prefix + 3] = pred; + to[synclog_length_prefix + 4] = bits; + to[synclog_length_prefix + 5] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(dst_cta_id); + CUTLASS_UNUSED(transaction_bytes); + CUTLASS_UNUSED(pred); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_fence_barrier_init(uint32_t line) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_fence_barrier_init) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_fence_barrier_init); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_fence_barrier_init, line); + #else + CUTLASS_UNUSED(line); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_fence_view_async_shared(uint32_t line) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_fence_view_async_shared) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_fence_view_async_shared); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_fence_view_async_shared, line); + #else + CUTLASS_UNUSED(line); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cp_async_wait( + uint32_t line, + uint32_t n) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cp_async_wait) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_cp_async_wait); + if (to == nullptr) return; 
+ synclog_emit_prefix(to, synclog_header_cp_async_wait, line); + to[synclog_length_prefix + 0] = n; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(n); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cp_async_wait_all(uint32_t line) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cp_async_wait_all) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_cp_async_wait_all); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cp_async_wait_all, line); + #else + CUTLASS_UNUSED(line); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cp_async_fence(uint32_t line) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cp_async_fence) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_cp_async_fence); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cp_async_fence, line); + #else + CUTLASS_UNUSED(line); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cp_async_nan( + uint32_t line, + uint32_t smem_addr, + const void* gmem_ptr, + uint32_t pred) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cp_async_nan) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_cp_async_nan); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cp_async_nan, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_ptr); + to[synclog_length_prefix + 2] = (uint32_t)((uint64_t)gmem_ptr >> 32); + to[synclog_length_prefix + 3] = pred; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(gmem_ptr); + CUTLASS_UNUSED(pred); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cp_async_zfill( + uint32_t line, + uint32_t smem_addr, + const void* gmem_ptr, + uint32_t pred, + uint32_t size) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cp_async_zfill) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_cp_async_zfill); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cp_async_zfill, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_ptr); + to[synclog_length_prefix + 2] = (uint32_t)((uint64_t)gmem_ptr >> 32); + to[synclog_length_prefix + 3] = pred; + to[synclog_length_prefix + 4] = size; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(gmem_ptr); + CUTLASS_UNUSED(pred); + CUTLASS_UNUSED(size); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cp_async( + uint32_t line, + uint32_t smem_addr, + const void* gmem_ptr, + uint32_t pred, + uint32_t size) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cp_async) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_cp_async); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cp_async, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_ptr); + to[synclog_length_prefix + 2] = (uint32_t)((uint64_t)gmem_ptr >> 32); + to[synclog_length_prefix + 3] = pred; + to[synclog_length_prefix + 4] = size; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + CUTLASS_UNUSED(gmem_ptr); + 
CUTLASS_UNUSED(pred); + CUTLASS_UNUSED(size); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_tma_load( + uint32_t line, + uint64_t gmem_int_desc, + uint32_t smem_int_mbar, + uint32_t smem_int_ptr) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_tma_load) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_tma_load); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_tma_load, line); + to[synclog_length_prefix + 0] = (uint32_t)((uint64_t)gmem_int_desc); + to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_int_desc >> 32); + to[synclog_length_prefix + 2] = smem_int_mbar; + to[synclog_length_prefix + 3] = smem_int_ptr; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(gmem_int_desc); + CUTLASS_UNUSED(smem_int_mbar); + CUTLASS_UNUSED(smem_int_ptr); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_tma_store( + uint32_t line, + uint64_t gmem_int_desc, + uint32_t smem_int_ptr) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_tma_store) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_tma_store); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_tma_store, line); + to[synclog_length_prefix + 0] = (uint32_t)((uint64_t)gmem_int_desc); + to[synclog_length_prefix + 1] = (uint32_t)((uint64_t)gmem_int_desc >> 32); + to[synclog_length_prefix + 2] = smem_int_ptr; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(gmem_int_desc); + CUTLASS_UNUSED(smem_int_ptr); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_tma_store_arrive(uint32_t line) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_tma_store_arrive) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_tma_store_arrive); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_tma_store_arrive, line); + #else + CUTLASS_UNUSED(line); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_tma_store_wait( + uint32_t line, + uint32_t count) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_tma_store_wait) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_tma_store_wait); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_tma_store_wait, line); + to[synclog_length_prefix + 0] = count; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(count); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_warpgroup_arrive( + uint32_t line) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_warpgroup_arrive) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_warpgroup_arrive); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_warpgroup_arrive, line); + #else + CUTLASS_UNUSED(line); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_warpgroup_wait( + uint32_t line, + uint32_t n) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_warpgroup_wait) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_warpgroup_wait); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_warpgroup_wait, line); + to[synclog_length_prefix + 0] = n; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(n); 
+ #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_warpgroup_commit_batch( + uint32_t line) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_warpgroup_commit_batch) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_warpgroup_commit_batch); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_warpgroup_commit_batch, line); + #else + CUTLASS_UNUSED(line); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_wgmma_reg_smem( + uint32_t line, + uint64_t desc_b) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_wgmma_reg_smem) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_wgmma_reg_smem); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_wgmma_reg_smem, line); + to[synclog_length_prefix + 0] = desc_b; + to[synclog_length_prefix + 1] = desc_b >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(desc_b); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_wgmma_smem_smem( + uint32_t line, + uint64_t desc_a, + uint64_t desc_b) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_wgmma_smem_smem) return; + if (!synclog_condition_emit()) return; + uint32_t* to = synclog_alloc(synclog_length_wgmma_smem_smem); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_wgmma_smem_smem, line); + to[synclog_length_prefix + 0] = desc_a; + to[synclog_length_prefix + 1] = desc_a >> 32; + to[synclog_length_prefix + 2] = desc_b; + to[synclog_length_prefix + 3] = desc_b >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(desc_a); + CUTLASS_UNUSED(desc_b); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +CUTLASS_DEVICE +void synclog_emit_cpasync_barrier_arrive( + uint32_t line, + uint32_t smem_addr) { + #if defined(CUTLASS_ENABLE_SYNCLOG) + if constexpr (!synclog_enable_cpasync_barrier_arrive) return; + if (!synclog_condition_emit()) return; + uint64_t bits = synclog_mbarrier_bits(smem_addr); + uint32_t* to = synclog_alloc(synclog_length_cpasync_barrier_arrive); + if (to == nullptr) return; + synclog_emit_prefix(to, synclog_header_cpasync_barrier_arrive, line); + to[synclog_length_prefix + 0] = smem_addr; + to[synclog_length_prefix + 1] = bits; + to[synclog_length_prefix + 2] = bits >> 32; + #else + CUTLASS_UNUSED(line); + CUTLASS_UNUSED(smem_addr); + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +#if !defined(CUTLASS_ENABLE_SYNCLOG) +CUTLASS_DEVICE +#elif defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) +static __attribute__((__noinline__)) __device__ +#else +static __attribute__((__noinline__)) +#endif +void synclog_print() { + #if defined(CUTLASS_ENABLE_SYNCLOG) + #if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) + if (synclog_buf == nullptr || !synclog_condition_print()) { + return; + } + printf("synclog start\n"); + for (uint32_t at = 1; at < synclog_cap; ) { + uint32_t header = synclog_buf[at]; + if (header == synclog_header_none) { + break; + } + printf("synclog at %u: ", at); + if constexpr (synclog_enable_syncthreads) { + if (header == synclog_header_syncthreads) { + synclog_print_prefix("syncthreads", at); + at += synclog_length_syncthreads; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_syncwarp) { + if (header == synclog_header_syncwarp) { + synclog_print_prefix("syncwarp", at); + at += synclog_length_syncwarp; + printf("\n"); + continue; + } + 
} + if constexpr (synclog_enable_named_barrier_arrive_and_wait) { + if (header == synclog_header_named_barrier_arrive_and_wait) { + synclog_print_prefix("named_barrier_arrive_and_wait", at); + at += synclog_length_named_barrier_arrive_and_wait; + printf("num_threads=%u barrier_id=%u\n", synclog_buf[at-2], synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_named_barrier_arrive) { + if (header == synclog_header_named_barrier_arrive) { + synclog_print_prefix("named_barrier_arrive", at); + at += synclog_length_named_barrier_arrive; + printf("num_threads=%u barrier_id=%u\n", synclog_buf[at-2], synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_cluster_barrier_init) { + if (header == synclog_header_cluster_barrier_init) { + synclog_print_prefix("cluster_barrier_init", at); + at += synclog_length_cluster_barrier_init; + printf("smem_addr=%u arrive_count=%u\n", synclog_buf[at-2], synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_cluster_barrier_wait) { + if (header == synclog_header_cluster_barrier_wait) { + synclog_print_prefix("cluster_barrier_wait", at); + at += synclog_length_cluster_barrier_wait; + printf("smem_addr=%u phase=%u", synclog_buf[at-4], synclog_buf[at-3]); + continue; + } + } + if constexpr (synclog_enable_cluster_barrier_test_wait) { + if (header == synclog_header_cluster_barrier_test_wait) { + synclog_print_prefix("cluster_barrier_test_wait", at); + at += synclog_length_cluster_barrier_test_wait; + printf("smem_addr=%u phase=%u pred=%u", synclog_buf[at-5], synclog_buf[at-4], synclog_buf[at-3]); + continue; + } + } + if constexpr (synclog_enable_cluster_barrier_try_wait) { + if (header == synclog_header_cluster_barrier_try_wait) { + synclog_print_prefix("cluster_barrier_try_wait", at); + at += synclog_length_cluster_barrier_try_wait; + printf("smem_addr=%u phase=%u", synclog_buf[at-4], synclog_buf[at-3]); + continue; + } + } + if constexpr (synclog_enable_cluster_barrier_arrive_cluster) { + if (header == synclog_header_cluster_barrier_arrive_cluster) { + synclog_print_prefix("cluster_barrier_arrive_cluster", at); + at += synclog_length_cluster_barrier_arrive_cluster; + printf("smem_addr=%u cta_id=%u pred=%u", synclog_buf[at-5], synclog_buf[at-4], synclog_buf[at-3]); + continue; + } + } + if constexpr (synclog_enable_cluster_barrier_arrive) { + if (header == synclog_header_cluster_barrier_arrive) { + synclog_print_prefix("cluster_barrier_arrive", at); + at += synclog_length_cluster_barrier_arrive; + printf("smem_addr=%u", synclog_buf[at-3]); + continue; + } + } + if constexpr (synclog_enable_cluster_barrier_invalidate) { + if (header == synclog_header_cluster_barrier_invalidate) { + synclog_print_prefix("cluster_barrier_invalidate", at); + at += synclog_length_cluster_barrier_invalidate; + printf("smem_addr=%u", synclog_buf[at-3]); + continue; + } + } + if constexpr (synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx) { + if (header == synclog_header_cluster_transaction_barrier_arrive_and_expect_tx) { + synclog_print_prefix("cluster_transaction_barrier_arrive_and_expect_tx", at); + at += synclog_length_cluster_transaction_barrier_arrive_and_expect_tx; + printf("smem_addr=%u transaction_bytes=%u", synclog_buf[at-4], synclog_buf[at-3]); + continue; + } + } + if constexpr (synclog_enable_cluster_transaction_barrier_arrive_and_expect_tx_cluster) { + if (header == synclog_header_cluster_transaction_barrier_arrive_and_expect_tx_cluster) { + 
synclog_print_prefix("cluster_transaction_barrier_arrive_and_expect_tx_cluster", at); + at += synclog_length_cluster_transaction_barrier_arrive_and_expect_tx_cluster; + printf("smem_addr=%u transaction_bytes=%u cta_id=%u pred=%u", synclog_buf[at-6], synclog_buf[at-5], synclog_buf[at-4], synclog_buf[at-3]); + continue; + } + } + if constexpr (synclog_enable_cluster_transaction_barrier_expect_transaction) { + if (header == synclog_header_cluster_transaction_barrier_expect_transaction) { + synclog_print_prefix("cluster_transaction_barrier_expect_transaction", at); + at += synclog_length_cluster_transaction_barrier_expect_transaction; + printf("smem_addr=%u transaction_bytes=%u", synclog_buf[at-4], synclog_buf[at-3]); + continue; + } + } + if constexpr (synclog_enable_cluster_transaction_barrier_complete_transaction) { + if (header == synclog_header_cluster_transaction_barrier_complete_transaction) { + synclog_print_prefix("cluster_transaction_barrier_complete_transaction", at); + at += synclog_length_cluster_transaction_barrier_complete_transaction; + printf("smem_addr=%u dst_cta_id=%u transaction_bytes=%u pred=%u", synclog_buf[at-6], synclog_buf[at-5], synclog_buf[at-4], synclog_buf[at-3]); + continue; + } + } + if constexpr (synclog_enable_fence_barrier_init) { + if (header == synclog_header_fence_barrier_init) { + synclog_print_prefix("fence_barrier_init", at); + at += synclog_length_fence_barrier_init; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_fence_view_async_shared) { + if (header == synclog_header_fence_view_async_shared) { + synclog_print_prefix("fence_view_async_shared", at); + at += synclog_length_fence_view_async_shared; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_cp_async_wait) { + if (header == synclog_header_cp_async_wait) { + synclog_print_prefix("cp_async_wait", at); + at += synclog_length_cp_async_wait; + printf("n=%u\n", synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_cp_async_wait_all) { + if (header == synclog_header_cp_async_wait_all) { + synclog_print_prefix("cp_async_wait_all", at); + at += synclog_length_cp_async_wait_all; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_cp_async_fence) { + if (header == synclog_header_cp_async_fence) { + synclog_print_prefix("cp_async_fence", at); + at += synclog_length_cp_async_fence; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_cp_async_nan) { + if (header == synclog_header_cp_async_nan) { + synclog_print_prefix("cp_async_nan", at); + at += synclog_length_cp_async_nan; + uint64_t gmem_addr = synclog_buf[at-3]; + gmem_addr += (uint64_t)synclog_buf[at-2] << 32; + printf("smem_addr=%u gmem_addr=%llu pred=%u\n", synclog_buf[at-4], gmem_addr, synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_cp_async_zfill) { + if (header == synclog_header_cp_async_zfill) { + synclog_print_prefix("cp_async_zfill", at); + at += synclog_length_cp_async_zfill; + uint64_t gmem_addr = synclog_buf[at-4]; + gmem_addr += (uint64_t)synclog_buf[at-3] << 32; + printf("smem_addr=%u gmem_addr=%llu pred=%u size=%u\n", synclog_buf[at-5], gmem_addr, synclog_buf[at-2], synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_cp_async) { + if (header == synclog_header_cp_async) { + synclog_print_prefix("cp_async", at); + at += synclog_length_cp_async; + uint64_t gmem_addr = synclog_buf[at-4]; + gmem_addr += (uint64_t)synclog_buf[at-3] << 32; + printf("smem_addr=%u gmem_addr=%llu pred=%u size=%u\n", synclog_buf[at-5], gmem_addr, 
synclog_buf[at-2], synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_tma_load) { + if (header == synclog_header_tma_load) { + synclog_print_prefix("tma_load", at); + at += synclog_length_tma_load; + uint64_t gmem_int_desc = synclog_buf[at-4]; + gmem_int_desc += (uint64_t)synclog_buf[at-3] << 32; + printf("gmem_int_desc=%llu smem_int_mbar=%u smem_int_ptr=%u\n", gmem_int_desc, synclog_buf[at-2], synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_tma_store) { + if (header == synclog_header_tma_store) { + synclog_print_prefix("tma_store", at); + at += synclog_length_tma_store; + uint64_t gmem_int_desc = synclog_buf[at-3]; + gmem_int_desc += (uint64_t)synclog_buf[at-2] << 32; + printf("gmem_int_desc=%llu smem_int_ptr=%u\n", gmem_int_desc, synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_tma_store_arrive) { + if (header == synclog_header_tma_store_arrive) { + synclog_print_prefix("tma_store_arrive", at); + at += synclog_length_tma_store_arrive; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_tma_store_wait) { + if (header == synclog_header_tma_store_wait) { + synclog_print_prefix("tma_store_wait", at); + at += synclog_length_tma_store_wait; + printf("count=%u\n", synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_warpgroup_arrive) { + if (header == synclog_header_warpgroup_arrive) { + synclog_print_prefix("warpgroup_arrive", at); + at += synclog_length_warpgroup_arrive; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_warpgroup_wait) { + if (header == synclog_header_warpgroup_wait) { + synclog_print_prefix("warpgroup_wait", at); + at += synclog_length_warpgroup_wait; + printf("n=%u\n", synclog_buf[at-1]); + continue; + } + } + if constexpr (synclog_enable_warpgroup_commit_batch) { + if (header == synclog_header_warpgroup_commit_batch) { + synclog_print_prefix("warpgroup_commit_batch", at); + at += synclog_length_warpgroup_commit_batch; + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_wgmma_reg_smem) { + if (header == synclog_header_wgmma_reg_smem) { + synclog_print_prefix("wgmma_reg_smem", at); + at += synclog_length_wgmma_reg_smem; + synclog_print_wgmma_desc("desc_b", synclog_buf[at-2], synclog_buf[at-1], ""); + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_wgmma_smem_smem) { + if (header == synclog_header_wgmma_smem_smem) { + synclog_print_prefix("wgmma_smem_smem", at); + at += synclog_length_wgmma_smem_smem; + synclog_print_wgmma_desc("desc_a", synclog_buf[at-4], synclog_buf[at-3], " "); + synclog_print_wgmma_desc("desc_b", synclog_buf[at-2], synclog_buf[at-1], ""); + printf("\n"); + continue; + } + } + if constexpr (synclog_enable_cpasync_barrier_arrive) { + if (header == synclog_header_cpasync_barrier_arrive) { + synclog_print_prefix("cpasync_barrier_arrive", at); + at += synclog_length_cpasync_barrier_arrive; + printf("smem_addr=%u", synclog_buf[at-3]); + continue; + } + } + asm volatile ("brkpt;\n" ::); + } + if (synclog_buf[0] >= synclog_cap) { + printf( + "synclog was truncated (exceeded capacity of %lu bytes)\n", + (synclog_cap - 1) * sizeof(uint32_t) + ); + } + printf("synclog end\n"); + #endif + #endif // defined(CUTLASS_ENABLE_SYNCLOG) +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(CUTLASS_ENABLE_SYNCLOG) +#undef __syncthreads +#define __syncthreads() do {\ + cutlass::arch::synclog_emit_syncthreads(__LINE__);\ + __syncthreads();\ +} while (0) +#endif // 
defined(CUTLASS_ENABLE_SYNCLOG) + +#if defined(CUTLASS_ENABLE_SYNCLOG) +#undef __syncwarp +#define __syncwarp(...) do {\ + cutlass::arch::synclog_emit_syncwarp(__LINE__);\ + __syncwarp(__VA_ARGS__);\ +} while (0) +#endif // defined(CUTLASS_ENABLE_SYNCLOG) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace arch +} // namespace cutlass diff --git a/include/cutlass/arch/wmma.h b/include/cutlass/arch/wmma.h index 13d0a9ed97..720895f385 100644 --- a/include/cutlass/arch/wmma.h +++ b/include/cutlass/arch/wmma.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/include/cutlass/arch/wmma_sm70.h b/include/cutlass/arch/wmma_sm70.h index 59d8eeed8e..d75ee2b075 100644 --- a/include/cutlass/arch/wmma_sm70.h +++ b/include/cutlass/arch/wmma_sm70.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -34,11 +34,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/layout/matrix.h" //////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/arch/wmma_sm72.h b/include/cutlass/arch/wmma_sm72.h index 0895c86b3a..b644181b80 100644 --- a/include/cutlass/arch/wmma_sm72.h +++ b/include/cutlass/arch/wmma_sm72.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -34,11 +34,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/layout/matrix.h" //////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/arch/wmma_sm75.h b/include/cutlass/arch/wmma_sm75.h index c2c9e068f5..f603605128 100644 --- a/include/cutlass/arch/wmma_sm75.h +++ b/include/cutlass/arch/wmma_sm75.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -34,11 +34,7 @@ #pragma once -#if defined(__CUDACC_RTC__) #include -#else -#include -#endif #include "cutlass/layout/matrix.h" //////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/array.h b/include/cutlass/array.h index 347002f008..e85d19facf 100644 --- a/include/cutlass/array.h +++ b/include/cutlass/array.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -35,8 +35,9 @@ #pragma once #include "cutlass/cutlass.h" +#include "cutlass/functional.h" #include "cutlass/numeric_types.h" - +#include "cutlass/platform/platform.h" namespace cutlass { //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -47,15 +48,31 @@ template < int N, bool RegisterSized = sizeof_bits::value >= 32 > -class Array; +struct Array; + +namespace detail { + +template +struct is_Array : platform::false_type {}; + +template < + typename T, + int N, + bool RegisterSized +> +struct is_Array > : platform::true_type {}; + +template +constexpr bool is_Array_v = is_Array::value; + +} // namespace detail //////////////////////////////////////////////////////////////////////////////////////////////////// /// Defines the size of an Array<> in bits template struct sizeof_bits > { - static int const value = - int(sizeof(typename Array::Storage)) * 8 * int(Array::kStorageElements); + static constexpr int value = sizeof(Array) * 8; }; //////////////////////////////////////////////////////////////////////////////////////////////////// @@ -81,8 +98,7 @@ template < typename T, int N > -class Array { -public: +struct Array { /// Storage type using Storage = T; @@ -92,10 +108,10 @@ class Array { /// Number of storage elements //static std::size_t const kStorageElements = N; - static size_t const kStorageElements = N; + static constexpr size_t kStorageElements = N; /// Number of logical elements - static size_t const kElements = N; + static constexpr size_t kElements = N; // // C++ standard members @@ -337,26 +353,9 @@ class Array { } }; -private: - /// Internal storage Storage storage[kElements]; -public: - - #if 0 - CUTLASS_HOST_DEVICE - Array() { } - - CUTLASS_HOST_DEVICE - Array(Array const &x) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kElements; ++i) { - storage[i] = x.storage[i]; - } - } - #endif - /// Efficient clear method CUTLASS_HOST_DEVICE void clear() { @@ -442,7 +441,7 @@ class Array { CUTLASS_HOST_DEVICE void fill(T const &value) { CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < kElements; ++i) { + for (int i = 0; i < int(kElements); ++i) { storage[i] = static_cast(value); } } @@ -452,6 +451,11 @@ class Array { return iterator(storage); } + CUTLASS_HOST_DEVICE + const_iterator begin() const { + return cbegin(); + } + CUTLASS_HOST_DEVICE const_iterator cbegin() const { return const_iterator(storage); @@ -462,6 +466,11 @@ class Array { return iterator(reinterpret_cast(storage + kStorageElements)); } + CUTLASS_HOST_DEVICE + const_iterator end() const { + return cend(); + } + CUTLASS_HOST_DEVICE const_iterator cend() const { return 
const_iterator(reinterpret_cast(storage + kStorageElements)); @@ -472,6 +481,11 @@ class Array { return reverse_iterator(reinterpret_cast(storage + kStorageElements)); } + CUTLASS_HOST_DEVICE + const_reverse_iterator rbegin() const { + return crbegin(); + } + CUTLASS_HOST_DEVICE const_reverse_iterator crbegin() const { return const_reverse_iterator(reinterpret_cast(storage + kStorageElements)); @@ -482,6 +496,11 @@ class Array { return reverse_iterator(reinterpret_cast(storage)); } + CUTLASS_HOST_DEVICE + const_reverse_iterator rend() const { + return crend(); + } + CUTLASS_HOST_DEVICE const_reverse_iterator crend() const { return const_reverse_iterator(reinterpret_cast(storage)); @@ -493,77 +512,2092 @@ class Array { }; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// Factories //////////////////////////////////////////////////////////////////////////////////////////////////// template CUTLASS_HOST_DEVICE Array make_Array(Element x) { - Array m; - m[0] = x; - return m; + return {x}; } template CUTLASS_HOST_DEVICE Array make_Array(Element x, Element y) { - Array m; - m[0] = x; - m[1] = y; - return m; + return {x,y}; } template CUTLASS_HOST_DEVICE Array make_Array(Element x, Element y, Element z) { - Array m; - m[0] = x; - m[1] = y; - m[2] = z; - return m; + return {x,y,z}; } template CUTLASS_HOST_DEVICE Array make_Array(Element x, Element y, Element z, Element w) { - Array m; - m[0] = x; - m[1] = y; - m[2] = z; - m[3] = w; - return m; + return {x,y,z,w}; } -//////////////////////////////////////////////////////////////////////////////////////////////////// -} // namespace cutlass +///////////////////////////////////////////////////////////////////////////////////////////////// +// functional.h numeric specializations +///////////////////////////////////////////////////////////////////////////////////////////////// -//////////////////////////////////////////////////////////////////////////////////////////////////// +template +struct absolute_value_op< Array > { -#include "cutlass/array_subbyte.h" + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs) const { -//////////////////////////////////////////////////////////////////////////////////////////////////// + Array result; + absolute_value_op scalar_op; -namespace cutlass { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i]); + } -//////////////////////////////////////////////////////////////////////////////////////////////////// + return result; + } +}; -/// Aligned array type -template < - /// Element type - typename T, - /// Number of elements in the array - int N, - /// Alignment requirement in bytes - int Alignment = sizeof_bits::value * N / 8 -> -class alignas(Alignment) AlignedArray: public Array { -public: +template +struct plus> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + + Array result; + plus scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], rhs[i]); + } + + return result; + } + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &scalar) const { + + Array result; + plus scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( T const &scalar, Array const &rhs) const { + + Array result; + plus scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, rhs[i]); + 
} + + return result; + } }; +template +struct minus> { -//////////////////////////////////////////////////////////////////////////////////////////////////// + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { -} // namespace cutlass + Array result; + minus scalar_op; -//////////////////////////////////////////////////////////////////////////////////////////////////// + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &scalar) const { + + Array result; + minus scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( T const &scalar, Array const &rhs) const { + + Array result; + minus scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, rhs[i]); + } + + return result; + } +}; + +template +struct multiplies> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + + Array result; + multiplies scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &scalar) const { + + Array result; + multiplies scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( T const &scalar, Array const &rhs) const { + + Array result; + multiplies scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, rhs[i]); + } + + return result; + } +}; + +template +struct maximum_absolute_value_reduction, PropogateNaN> { + + CUTLASS_HOST_DEVICE + T operator() (T const& scalar, Array const& rhs) const { + + T result = scalar; + maximum_absolute_value_reduction scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result = scalar_op(result, rhs[i]); + } + + return result; + } +}; + +template +struct scale> { + T const scaling_factor_; + + CUTLASS_HOST_DEVICE + scale(T scaling_factor) : scaling_factor_(scaling_factor) { + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const & rhs) const { + Array result; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = rhs[i] * scaling_factor_; + } + + return result; + } +}; + +template +struct divides> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + + Array result; + divides scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &scalar) const { + + Array result; + divides scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( T const &scalar, Array const &rhs) const { + + Array result; + divides scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, rhs[i]); + } + + return result; + } +}; + +template +struct reciprocal_approximate> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs) const { + + Array result; + reciprocal_approximate scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; 
++i) { + result[i] = scalar_op(lhs[i]); + } + + return result; + } +}; + +template +struct reciprocal_approximate_ftz> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs) const { + + Array result; + reciprocal_approximate_ftz scalar_op; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i]); + } + + return result; + } +}; + +template +struct maximum, PropagateNaN> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + + Array result; + maximum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &scalar) const { + + Array result; + maximum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(T const &scalar, Array const &rhs) const { + + Array result; + maximum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, rhs[i]); + } + + return result; + } +}; + +template +struct minimum, PropagateNaN> { + + CUTLASS_HOST_DEVICE + static T scalar_op(T const &lhs, T const &rhs) { + return (rhs < lhs ? rhs : lhs); + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + + Array result; + minimum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], rhs[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &scalar) const { + + Array result; + minimum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i], scalar); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(T const &scalar, Array const &rhs) const { + + Array result; + minimum scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, rhs[i]); + } + + return result; + } +}; + +template +struct minimum_with_nan_propagation> : minimum, true> +{}; + +template +struct negate> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs) const { + + Array result; + negate scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(lhs[i]); + } + + return result; + } +}; + +/// Fused multiply-add +template +struct multiply_add, Array, Array> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, Array const &b, Array const &c) const { + + Array result; + multiply_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i], b[i], c[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, T const &scalar, Array const &c) const { + + Array result; + multiply_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i], scalar, c[i]); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(T const &scalar, Array const &b, Array const &c) const { + + Array result; + multiply_add scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(scalar, b[i], c[i]); + } + + return result; + } +}; + +/// Fused square-and-plus +template +struct square_and_plus> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, Array const &rhs) const { + multiply_add, Array, Array> ma_op; + return ma_op(rhs, rhs, lhs); + } + + 
CUTLASS_HOST_DEVICE + Array operator()(Array const &lhs, T const &rhs) const { + plus> plus_op; + multiplies multiplies_op; + return plus_op(multiplies_op(rhs, rhs), lhs); + } +}; + +/// Inverse-square-root +template +struct inverse_square_root> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &a) const { + Array result; + inverse_square_root scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i]); + } + return result; + } +}; + +template +struct inverse_square_root> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & a) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = h2rsqrt(a_ptr[i]); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + __half d_residual = hrsqrt(a_residual_ptr[N - 1]); + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + inverse_square_root scalar_op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = scalar_op(a[i]); + } + + #endif + + return result; + } +}; + +/// Fused multiply-add-relu0 +template +struct multiply_add_relu0, Array, Array> { + + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, Array const &b, Array const &c) const { + + Array result; + multiply_add scalar_op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(scalar_op(a[i], b[i], c[i]), T(0)); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, T const &scalar, Array const &c) const { + + Array result; + multiply_add scalar_op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(scalar_op(a[i], scalar, c[i]), T(0)); + } + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(T const &scalar, Array const &b, Array const &c) const { + + Array result; + multiply_add scalar_op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(scalar_op(scalar, b[i], c[i]), T(0)); + } + + return result; + } +}; + + +template +struct conjugate > { + CUTLASS_HOST_DEVICE + Array operator()(Array const &a) const { + + conjugate conj_op; + + Array ca; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + ca[i] = conj_op(a[i]); + } + return ca; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// functional.h numeric specializations targeting SIMD instructions in device code. 
+///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct plus> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hadd2(lhs_ptr[i], rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + __half d_residual = __hadd(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] + rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(half_t const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hadd2(lhs_pair, rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + __half d_residual = __hadd(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs + rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, half_t const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hadd2(lhs_ptr[i], rhs_pair); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half d_residual = __hadd(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] + rhs; + } + #endif + + return result; + } +}; + +template +struct minus> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hsub2(lhs_ptr[i], rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + __half d_residual = __hsub(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); + + result[N - 1] = 
reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] - rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(half_t const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hsub2(lhs_pair, rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + __half d_residual = __hsub(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs - rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, half_t const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hsub2(lhs_ptr[i], rhs_pair); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half d_residual = __hsub(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] - rhs; + } + #endif + + return result; + } +}; + +template +struct multiplies> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hmul2(lhs_ptr[i], rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + __half d_residual = __hmul(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] * rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(half_t const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hmul2(lhs_pair, rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + + __half d_residual = __hmul( + reinterpret_cast<__half const &>(lhs), + b_residual_ptr[N - 1]); + + result[N - 1] = 
reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs * rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, half_t const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hmul2(lhs_ptr[i], rhs_pair); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + + __half d_residual = __hmul( + a_residual_ptr[N - 1], + reinterpret_cast<__half const &>(rhs)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] * rhs; + } + #endif + + return result; + } +}; + +template +struct divides> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __h2div(lhs_ptr[i], rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + + __half d_residual = __hdiv( + a_residual_ptr[N - 1], + b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] / rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(half_t const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __h2div(lhs_pair, rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + + __half d_residual = __hdiv( + reinterpret_cast<__half const &>(lhs), + b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs / rhs[i]; + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, half_t const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __h2div(lhs_ptr[i], rhs_pair); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + + __half d_residual = __hdiv( + a_residual_ptr[N - 1], + reinterpret_cast<__half const &>(rhs)); + + result[N - 1] = 
reinterpret_cast(d_residual); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = lhs[i] / rhs; + } + #endif + + return result; + } +}; + +template +struct negate> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *source_ptr = reinterpret_cast<__half2 const *>(&lhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hneg2(source_ptr[i]); + } + + if constexpr (N % 2) { + half_t x = -lhs[N - 1]; + __half lhs_val = reinterpret_cast<__half const &>(x); + result[N - 1] = reinterpret_cast(lhs_val); + } + + #else + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = -lhs[i]; + } + #endif + + return result; + } +}; + +/// Fused multiply-add +template +struct multiply_add, Array, Array> { + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + Array const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); + __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); + __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_ptr[i]); + } + + if constexpr (N % 2) { + + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); + __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); + + __half d_residual = __hfma( + a_residual_ptr[N - 1], + b_residual_ptr[N - 1], + c_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a[i], b[i], c[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + half_t const &a, + Array const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 a_pair = __half2half2(reinterpret_cast<__half const &>(a)); + __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); + __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2(a_pair, b_ptr[i], c_ptr[i]); + } + + if constexpr (N % 2) { + + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); + __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); + __half d_residual = __hfma( + reinterpret_cast<__half const &>(a), + b_residual_ptr[N - 1], + c_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a, b[i], c[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + half_t const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); + __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b)); + __half2 const *c_ptr = 
reinterpret_cast<__half2 const *>(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2(a_ptr[i], b_pair, c_ptr[i]); + } + + if constexpr (N % 2) { + + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); + + __half d_residual = __hfma( + a_residual_ptr[N - 1], + reinterpret_cast<__half const &>(b), + c_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a[i], b, c[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + Array const &b, + half_t const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 530) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); + __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); + __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2(a_ptr[i], b_ptr[i], c_pair); + } + + if constexpr (N % 2) { + + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); + + __half d_residual = __hfma( + a_residual_ptr[N - 1], + b_residual_ptr[N - 1], + reinterpret_cast<__half const &>(c)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a[i], b[i], c); + } + #endif + + return result; + } +}; + +/// Fused multiply-add-relu0 +template +struct multiply_add_relu0, Array, Array> { + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + Array const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); + __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); + __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2_relu(a_ptr[i], b_ptr[i], c_ptr[i]); + } + + if constexpr (N % 2) { + + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); + __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); + + __half d_residual = __hfma_relu( + a_residual_ptr[N - 1], + b_residual_ptr[N - 1], + c_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(op(a[i], b[i], c[i]), (half_t)0); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + half_t const &a, + Array const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 a_pair = __half2half2(reinterpret_cast<__half const &>(a)); + __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); + __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2_relu(a_pair, b_ptr[i], c_ptr[i]); + } + 
+ if constexpr (N % 2) { + + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); + __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); + __half d_residual = __hfma_relu( + reinterpret_cast<__half const &>(a), + b_residual_ptr[N - 1], + c_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(op(a, b[i], c[i]), half_t(0)); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + half_t const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); + __half2 b_pair = __half2half2(reinterpret_cast<__half const &>(b)); + __half2 const *c_ptr = reinterpret_cast<__half2 const *>(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2_relu(a_ptr[i], b_pair, c_ptr[i]); + } + + if constexpr (N % 2) { + + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + __half const *c_residual_ptr = reinterpret_cast<__half const *>(&c); + + __half d_residual = __hfma_relu( + a_residual_ptr[N - 1], + reinterpret_cast<__half const &>(b), + c_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(op(a[i], b, c[i]), half_t(0)); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + Array const &b, + half_t const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *a_ptr = reinterpret_cast<__half2 const *>(&a); + __half2 const *b_ptr = reinterpret_cast<__half2 const *>(&b); + __half2 c_pair = __half2half2(reinterpret_cast<__half const &>(c)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = __hfma2_relu(a_ptr[i], b_ptr[i], c_pair); + } + + if constexpr (N % 2) { + + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&a); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&b); + + __half d_residual = __hfma_relu( + a_residual_ptr[N - 1], + b_residual_ptr[N - 1], + reinterpret_cast<__half const &>(c)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + multiply_add op; + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(op(a[i], b[i], c), half_t(0)); + } + #endif + + return result; + } +}; + +template +struct minimum, PropagateNaN> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = PropagateNaN ? __hmin2_nan(lhs_ptr[i], rhs_ptr[i]) + : __hmin2(lhs_ptr[i], rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + + __half d_residual = PropagateNaN ? 
__hmin_nan(a_residual_ptr[N - 1], b_residual_ptr[N - 1]) + : __hmin(a_residual_ptr[N - 1], b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + minimum mn; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mn(lhs[i],rhs[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(half_t const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs)); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = PropagateNaN ? __hmin2_nan(lhs_pair, rhs_ptr[i]) + : __hmin2(lhs_pair, rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + + __half d_residual = PropagateNaN ? __hmin_nan(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]) + : __hmin(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + minimum mn; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mn(lhs, rhs[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, half_t const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = PropagateNaN ? __hmin2_nan(lhs_ptr[i], rhs_pair) + : __hmin2(lhs_ptr[i], rhs_pair); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + + __half d_residual = PropagateNaN ? __hmin_nan(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)) + : __hmin(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + minimum mn; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mn(lhs[i], rhs); + } + #endif + + return result; + } +}; + +template +struct maximum, PropagateNaN> { + CUTLASS_HOST_DEVICE + Array operator()(Array const & lhs, Array const &rhs) const { + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + __half2 *result_ptr = reinterpret_cast<__half2 *>(&result); + __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs); + __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + result_ptr[i] = PropagateNaN ? __hmax2_nan(lhs_ptr[i], rhs_ptr[i]) + : __hmax2(lhs_ptr[i], rhs_ptr[i]); + } + + if constexpr (N % 2) { + __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs); + __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs); + + __half d_residual = PropagateNaN ? 
__hmax_nan(a_residual_ptr[N - 1], b_residual_ptr[N - 1])
+                                       : __hmax(a_residual_ptr[N - 1], b_residual_ptr[N - 1]);
+
+      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
+    }
+
+    #else
+
+    maximum<half_t, PropagateNaN> mx;
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+      result[i] = mx(lhs[i], rhs[i]);
+    }
+    #endif
+
+    return result;
+  }
+
+  CUTLASS_HOST_DEVICE
+  Array<half_t, N> operator()(half_t const & lhs, Array<half_t, N> const &rhs) const {
+    Array<half_t, N> result;
+    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
+
+    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
+    __half2 lhs_pair = __half2half2(reinterpret_cast<__half const &>(lhs));
+    __half2 const *rhs_ptr = reinterpret_cast<__half2 const *>(&rhs);
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N / 2; ++i) {
+      result_ptr[i] = PropagateNaN ? __hmax2_nan(lhs_pair, rhs_ptr[i])
+                                   : __hmax2(lhs_pair, rhs_ptr[i]);
+    }
+
+    if constexpr (N % 2) {
+      __half const *b_residual_ptr = reinterpret_cast<__half const *>(&rhs);
+
+      __half d_residual = PropagateNaN ? __hmax_nan(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1])
+                                       : __hmax(reinterpret_cast<__half const &>(lhs), b_residual_ptr[N - 1]);
+
+      result[N - 1] = reinterpret_cast<half_t const &>(d_residual);
+    }
+
+    #else
+
+    maximum<half_t, PropagateNaN> mx;
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N; ++i) {
+      result[i] = mx(lhs, rhs[i]);
+    }
+    #endif
+
+    return result;
+  }
+
+  CUTLASS_HOST_DEVICE
+  Array<half_t, N> operator()(Array<half_t, N> const & lhs, half_t const &rhs) const {
+    Array<half_t, N> result;
+    #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800)
+
+    __half2 *result_ptr = reinterpret_cast<__half2 *>(&result);
+    __half2 const *lhs_ptr = reinterpret_cast<__half2 const *>(&lhs);
+    __half2 rhs_pair = __half2half2(reinterpret_cast<__half const &>(rhs));
+
+    CUTLASS_PRAGMA_UNROLL
+    for (int i = 0; i < N / 2; ++i) {
+      result_ptr[i] = PropagateNaN ? __hmax2_nan(lhs_ptr[i], rhs_pair)
+                                   : __hmax2(lhs_ptr[i], rhs_pair);
+    }
+
+    if constexpr (N % 2) {
+      __half const *a_residual_ptr = reinterpret_cast<__half const *>(&lhs);
+
+      __half d_residual = PropagateNaN ?
__hmax_nan(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)) + : __hmax(a_residual_ptr[N - 1], reinterpret_cast<__half const &>(rhs)); + + result[N - 1] = reinterpret_cast(d_residual); + } + + #else + + maximum mx; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = mx(lhs[i], rhs); + } + #endif + + return result; + } +}; + +/// Fused multiply-add +template +struct multiply_add, Array, Array> { + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + Array const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + unsigned *result_ptr = reinterpret_cast(&result); + unsigned const *a_ptr = reinterpret_cast(&a); + unsigned const *b_ptr = reinterpret_cast(&b); + unsigned const *c_ptr = reinterpret_cast(&c); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(result_ptr[i]) + : "r"(a_ptr[i]), "r"(b_ptr[i]), "r"(c_ptr[i]) + ); + } + + if constexpr (N % 2) { + + uint16_t *result_ptr = reinterpret_cast(&result); + uint16_t const *a_residual_ptr = reinterpret_cast(&a); + uint16_t const *b_residual_ptr = reinterpret_cast(&b); + uint16_t const *c_residual_ptr = reinterpret_cast(&c); + + asm ("fma.rn.bf16 %0, %1, %2, %3;\n" + : "=h"(result_ptr[N - 1]) + : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[N - 1]) + ); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a[i], b[i], c[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + bfloat16_t const &a, + Array const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + unsigned *result_ptr = reinterpret_cast(&result); + + unsigned const *b_ptr = reinterpret_cast(&b); + unsigned const *c_ptr = reinterpret_cast(&c); + + unsigned a_packed = static_cast(a.raw()); + a_packed = (a_packed | (a_packed << 16)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(result_ptr[i]) + : "r"(a_packed), "r"(b_ptr[i]), "r"(c_ptr[i]) + ); + } + + if constexpr (N % 2) { + + uint16_t *result_ptr = reinterpret_cast(&result); + uint16_t const *a_residual_ptr = reinterpret_cast(&a); + uint16_t const *b_residual_ptr = reinterpret_cast(&b); + uint16_t const *c_residual_ptr = reinterpret_cast(&c); + + asm ("fma.rn.bf16 %0, %1, %2, %3;\n" + : "=h"(result_ptr[N - 1]) + : "h"(a_residual_ptr[0]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[N - 1]) + ); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a, b[i], c[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + bfloat16_t const &b, + Array const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + unsigned *result_ptr = reinterpret_cast(&result); + + unsigned const *a_ptr = reinterpret_cast(&a); + unsigned const *c_ptr = reinterpret_cast(&c); + + unsigned b_packed = static_cast(b.raw()); + b_packed = (b_packed | (b_packed << 16)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(result_ptr[i]) + : "r"(a_ptr[i]), "r"(b_packed), "r"(c_ptr[i]) + ); + } + + if constexpr (N % 2) { + + uint16_t *result_ptr = reinterpret_cast(&result); + uint16_t const *a_residual_ptr = reinterpret_cast(&a); + uint16_t const *b_residual_ptr = reinterpret_cast(&b); 
+ uint16_t const *c_residual_ptr = reinterpret_cast(&c); + + asm ("fma.rn.bf16 %0, %1, %2, %3;\n" + : "=h"(result_ptr[N - 1]) + : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[0]), "h"(c_residual_ptr[N - 1]) + ); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a[i], b, c[i]); + } + #endif + + return result; + } + + CUTLASS_HOST_DEVICE + Array operator()( + Array const &a, + Array const &b, + bfloat16_t const &c) const { + + Array result; + #if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + + unsigned *result_ptr = reinterpret_cast(&result); + + unsigned const *a_ptr = reinterpret_cast(&a); + unsigned const *b_ptr = reinterpret_cast(&b); + + unsigned c_packed = static_cast(c.raw()); + c_packed = (c_packed | (c_packed << 16)); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 2; ++i) { + asm ("fma.rn.bf16x2 %0, %1, %2, %3;\n" + : "=r"(result_ptr[i]) + : "r"(a_ptr[i]), "r"(b_ptr[i]), "r"(c_packed) + ); + } + + if constexpr (N % 2) { + + uint16_t *result_ptr = reinterpret_cast(&result); + uint16_t const *a_residual_ptr = reinterpret_cast(&a); + uint16_t const *b_residual_ptr = reinterpret_cast(&b); + uint16_t const *c_residual_ptr = reinterpret_cast(&c); + + asm ("fma.rn.bf16 %0, %1, %2, %3;\n" + : "=h"(result_ptr[N - 1]) + : "h"(a_residual_ptr[N - 1]), "h"(b_residual_ptr[N - 1]), "h"(c_residual_ptr[0]) + ); + } + + #else + + multiply_add op; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N; ++i) { + result[i] = op(a[i], b[i], c); + } + #endif + + return result; + } +}; + + +/// bit_and +template +struct bit_and> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, Array const &b) const { + using ArrayType = Array; + using Storage = typename ArrayType::Storage; + ArrayType result; + + Storage *result_data = result.raw_data(); + Storage const *a_data = a.raw_data(); + Storage const *b_data = b.raw_data(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ArrayType::kStorageElements; ++i) { + result_data[i] = (a_data[i] & b_data[i]); + } + + return result; + } +}; + + +/// bit_or +template +struct bit_or> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, Array const &b) const { + using ArrayType = Array; + using Storage = typename ArrayType::Storage; + ArrayType result; + + Storage *result_data = result.raw_data(); + Storage const *a_data = a.raw_data(); + Storage const *b_data = b.raw_data(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ArrayType::kStorageElements; ++i) { + result_data[i] = (a_data[i] | b_data[i]); + } + + return result; + } +}; + + +/// bit_not +template +struct bit_not> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &a) const { + using ArrayType = Array; + using Storage = typename ArrayType::Storage; + ArrayType result; + + Storage *result_data = result.raw_data(); + Storage const *a_data = a.raw_data(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ArrayType::kStorageElements; ++i) { + result_data[i] = (~a_data[i]); + } + + return result; + } +}; + + +/// bit_xor +template +struct bit_xor> { + CUTLASS_HOST_DEVICE + Array operator()(Array const &a, Array const &b) const { + using ArrayType = Array; + using Storage = typename ArrayType::Storage; + ArrayType result; + + Storage *result_data = result.raw_data(); + Storage const *a_data = a.raw_data(); + Storage const *b_data = b.raw_data(); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < ArrayType::kStorageElements; ++i) { + result_data[i] = (a_data[i] ^ b_data[i]); + } + + return result; + } +}; + 
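
The functor specializations above are what let elementwise math on `cutlass::Array<T, N>` lower to packed `__half2`/`bf16x2` instructions on SM53+/SM80+ while falling back to scalar loops on the host and on older architectures. The standalone sketch below is editorial, not part of this diff; the function name `axpy`, the fragment width of 8, and the exact include set are assumptions chosen for illustration. It shows how these functors and the operator overloads in the next section are typically exercised:

```cpp
#include "cutlass/cutlass.h"
#include "cutlass/array.h"
#include "cutlass/functional.h"
#include "cutlass/half.h"

// Illustrative only: alpha * x + y on a register-resident fragment.
// On SM53+ the multiply_add/plus specializations above use __hfma2/__hadd2
// over packed half pairs; otherwise the scalar fallback loop runs.
CUTLASS_HOST_DEVICE
cutlass::Array<cutlass::half_t, 8> axpy(
    cutlass::half_t alpha,
    cutlass::Array<cutlass::half_t, 8> const &x,
    cutlass::Array<cutlass::half_t, 8> const &y) {

  // B and C default to A in the multiply_add declaration, so this picks the
  // Array-by-Array specialization defined above.
  cutlass::multiply_add<cutlass::Array<cutlass::half_t, 8>> fma_op;

  // Scalar-broadcast overload: operator()(half_t const &, Array const &, Array const &)
  cutlass::Array<cutlass::half_t, 8> r = fma_op(alpha, x, y);

  // The free operator overloads (next section) give the same elementwise
  // semantics with infix syntax.
  return r + y;
}
```

The same pattern applies to `bfloat16_t`, where the fused path issues `fma.rn.bf16x2` via inline PTX as shown in the specialization above.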
+/////////////////////////////////////////////////////////////////////////////////////////////////
+// Operator overloads
+/////////////////////////////////////////////////////////////////////////////////////////////////
+
+template <typename T, int N>
+CUTLASS_HOST_DEVICE
+Array<T, N> operator+(Array<T, N> const &lhs, Array<T, N> const &rhs) {
+  plus<Array<T, N>> op;
+  return op(lhs, rhs);
+}
+
+template <typename T, int N>
+CUTLASS_HOST_DEVICE
+Array<T, N> operator+(T const &lhs, Array<T, N> const &rhs) {
+  plus<Array<T, N>> op;
+  return op(lhs, rhs);
+}
+
+template <typename T, int N>
+CUTLASS_HOST_DEVICE
+Array<T, N> operator+(Array<T, N> const &lhs, T const &rhs) {
+  plus<Array<T, N>> op;
+  return op(lhs, rhs);
+}
+
+template <typename T, int N>
+CUTLASS_HOST_DEVICE
+Array<T, N> operator-(Array<T, N> const &lhs, Array<T, N> const &rhs) {
+  minus<Array<T, N>> op;
+  return op(lhs, rhs);
+}
+
+template <typename T, int N>
+CUTLASS_HOST_DEVICE
+Array<T, N> operator-(Array<T, N> const &lhs) {
+  negate<Array<T, N>> op;
+  return op(lhs);
+}
+
+template <typename T, int N>
+CUTLASS_HOST_DEVICE
+Array<T, N> operator*(Array<T, N> const &lhs, Array<T, N> const &rhs) {
+  multiplies<Array<T, N>> op;
+  return op(lhs, rhs);
+}
+
+template <typename T, int N>
+CUTLASS_HOST_DEVICE
+Array<T, N> operator*(T lhs, Array<T, N> const &rhs) {
+  multiplies<Array<T, N>> op;
+  return op(lhs, rhs);
+}
+
+template <typename T, int N>
+CUTLASS_HOST_DEVICE
+Array<T, N> operator*(Array<T, N> const &lhs, T rhs) {
+  multiplies<Array<T, N>> op;
+  return op(lhs, rhs);
+}
+
+template <typename T, int N>
+CUTLASS_HOST_DEVICE
+Array<T, N> operator/(Array<T, N> const &lhs, Array<T, N> const &rhs) {
+  divides<Array<T, N>> op;
+  return op(lhs, rhs);
+}
+
+template <typename T, int N>
+CUTLASS_HOST_DEVICE
+Array<T, N> fma(Array<T, N> const &a, Array<T, N> const &b, Array<T, N> const &c) {
+  multiply_add<Array<T, N>> op;
+  return op(a, b, c);
+}
+
+template <typename T, int N>
+CUTLASS_HOST_DEVICE
+Array<T, N> fma(T a, Array<T, N> const &b, Array<T, N> const &c) {
+  multiply_add<Array<T, N>> op;
+  return op(a, b, c);
+}
+
+template <typename T, int N>
+CUTLASS_HOST_DEVICE
+Array<T, N> fma(Array<T, N> const &a, T b, Array<T, N> const &c) {
+  multiply_add<Array<T, N>> op;
+  return op(a, b, c);
+}
+
+template <typename T, int N>
+CUTLASS_HOST_DEVICE
+Array<T, N> fma(Array<T, N> const &a, Array<T, N> const &b, T c) {
+  multiply_add<Array<T, N>> op;
+  return op(a, b, c);
+}
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+// AlignedArray
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+/// Aligned array type
+template <
+  /// Element type
+  typename T,
+  /// Number of elements in the array
+  int N,
+  /// Alignment requirement in bytes
+  int Alignment = ( sizeof_bits<T>::value * N + 7 ) / 8
+>
+class alignas(Alignment) AlignedArray: public Array<T, N> {
+public:
+
+};
+
+} // namespace cutlass
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
+
+#include "cutlass/array_subbyte.h"
+
+////////////////////////////////////////////////////////////////////////////////////////////////////
diff --git a/include/cutlass/array_planar_complex.h b/include/cutlass/array_planar_complex.h
index 0d8d2c9899..2dd8aa84e1 100644
--- a/include/cutlass/array_planar_complex.h
+++ b/include/cutlass/array_planar_complex.h
@@ -1,5 +1,5 @@
 /***************************************************************************************************
- * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -51,13 +51,12 @@ struct ArrayPlanarComplex { using Element = Element_; /// Number of logical elements - static size_t const kElements = N; + static constexpr size_t kElements = N; /// Underlying Fragment of real-valued elemenets - using ArrayReal = Array; + using ArrayReal = cutlass::Array; public: - /// Fragment of real-valued elements representing the real part ArrayReal real; @@ -65,19 +64,6 @@ struct ArrayPlanarComplex { ArrayReal imag; public: - - /// Ctor - CUTLASS_HOST_DEVICE - ArrayPlanarComplex() { } - - /// Ctor - CUTLASS_HOST_DEVICE - ArrayPlanarComplex( - ArrayReal const &real_, - ArrayReal const &imag_ - ): - real(real_), imag(imag_) { } - /// Sets the array to zero efficiently CUTLASS_HOST_DEVICE void clear() { @@ -93,7 +79,7 @@ template CUTLASS_HOST_DEVICE ArrayPlanarComplex make_ArrayPlanarComplex(Array const &real, Array const &imag) { - return ArrayPlanarComplex(real, imag); + return ArrayPlanarComplex{real, imag}; } ///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/array_subbyte.h b/include/cutlass/array_subbyte.h index 12b1665f01..d2e0e5efdb 100644 --- a/include/cutlass/array_subbyte.h +++ b/include/cutlass/array_subbyte.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -48,10 +48,8 @@ template < typename T, int N > -class Array { -public: - - static int const kSizeBits = sizeof_bits::value * N; +struct Array { + static constexpr int kSizeBits = sizeof_bits::value * N; /// Storage type using Storage = typename platform::conditional< @@ -68,16 +66,16 @@ class Array { using Element = T; /// Number of logical elements per stored object - static int const kElementsPerStoredItem = int(sizeof(Storage) * 8) / sizeof_bits::value; + static constexpr int kElementsPerStoredItem = int(sizeof(Storage) * 8) / sizeof_bits::value; /// Number of storage elements - static size_t const kStorageElements = N / kElementsPerStoredItem; + static constexpr size_t kStorageElements = (N + kElementsPerStoredItem - 1) / kElementsPerStoredItem; /// Number of logical elements - static size_t const kElements = N; + static constexpr size_t kElements = N; /// Bitmask for covering one item - static Storage const kMask = ((Storage(1) << sizeof_bits::value) - 1); + static constexpr Storage kMask = ((Storage(1) << sizeof_bits::value) - 1); // // C++ standard members with pointer types removed @@ -96,16 +94,14 @@ class Array { /// Reference object inserts or extracts sub-byte items class reference { /// Pointer to storage element - Storage *ptr_; + Storage *ptr_{nullptr}; /// Index into elements packed into Storage object - int idx_; + int idx_{0}; public: - /// Default ctor - CUTLASS_HOST_DEVICE - reference(): ptr_(nullptr), idx_(0) { } + reference() = default; /// Ctor CUTLASS_HOST_DEVICE @@ -114,11 +110,38 @@ class Array { /// Assignment CUTLASS_HOST_DEVICE reference &operator=(T x) { + // `*ptr_ & kUpdateMask` will read ptr_ before write to it + // This means code pattern like + // + // ```cpp + // Array result; + // result[0] = xxx; + // ``` + // + // Will leads to 
compiler warning on use of unintialized member variable. Although we know + // this read of uninitialized member variable is harmeless. + +#if defined(__clang__) +# pragma clang diagnostic push +# pragma clang diagnostic ignored "-Wuninitialized" +#elif defined(__GNUC__) +# pragma GCC diagnostic push +# pragma GCC diagnostic ignored "-Wuninitialized" +# pragma GCC diagnostic ignored "-Wmaybe-uninitialized" +#endif + Storage item = (reinterpret_cast(x) & kMask); Storage kUpdateMask = Storage(~(kMask << (idx_ * sizeof_bits::value))); + *ptr_ = Storage(((*ptr_ & kUpdateMask) | (item << idx_ * sizeof_bits::value))); +#if defined(__clang__) +# pragma clang diagnostic pop +#elif defined(__GNUC__) +# pragma GCC diagnostic pop +#endif + return *this; } @@ -151,16 +174,14 @@ class Array { class const_reference { /// Pointer to storage element - Storage const *ptr_; + Storage const *ptr_{nullptr}; /// Index into elements packed into Storage object - int idx_; + int idx_{0}; public: - /// Default ctor - CUTLASS_HOST_DEVICE - const_reference(): ptr_(nullptr), idx_(0) { } + const_reference() = default; /// Ctor CUTLASS_HOST_DEVICE @@ -200,15 +221,14 @@ class Array { class iterator { /// Pointer to storage element - Storage *ptr_; + Storage *ptr_{nullptr}; /// Index into elements packed into Storage object - int idx_; + int idx_{0}; public: - CUTLASS_HOST_DEVICE - iterator(): ptr_(nullptr), idx_(0) { } + iterator() = default; CUTLASS_HOST_DEVICE iterator(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } @@ -279,15 +299,14 @@ class Array { class const_iterator { /// Pointer to storage element - Storage const *ptr_; + Storage const *ptr_{nullptr}; /// Index into elements packed into Storage object - int idx_; + int idx_{0}; public: - CUTLASS_HOST_DEVICE - const_iterator(): ptr_(nullptr), idx_(0) { } + const_iterator() = default; CUTLASS_HOST_DEVICE const_iterator(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } @@ -358,62 +377,36 @@ class Array { class reverse_iterator { /// Pointer to storage element - Storage *ptr_; + Storage *ptr_{nullptr}; /// Index into elements packed into Storage object - int idx_; + int idx_{0}; public: - CUTLASS_HOST_DEVICE - reverse_iterator(): ptr_(nullptr), idx_(0) { } + reverse_iterator() = default; CUTLASS_HOST_DEVICE reverse_iterator(Storage *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } - - // TODO }; /// Bidirectional constant iterator over elements class const_reverse_iterator { /// Pointer to storage element - Storage const *ptr_; + Storage const *ptr_{nullptr}; /// Index into elements packed into Storage object - int idx_; + int idx_{0}; public: - CUTLASS_HOST_DEVICE - const_reverse_iterator(): ptr_(nullptr), idx_(0) { } + const_reverse_iterator() = default; CUTLASS_HOST_DEVICE const_reverse_iterator(Storage const *ptr, int idx = 0): ptr_(ptr), idx_(idx) { } - - // TODO }; -private: - - /// Internal storage - Storage storage[kStorageElements]; - -public: - - #if 0 - CUTLASS_HOST_DEVICE - Array() { } - - CUTLASS_HOST_DEVICE - Array(Array const &x) { - CUTLASS_PRAGMA_UNROLL - for (int i = 0; i < int(kStorageElements); ++i) { - storage[i] = x.storage[i]; - } - } - #endif - /// Efficient clear method CUTLASS_HOST_DEVICE void clear() { @@ -484,7 +477,6 @@ class Array { return storage; } - CUTLASS_HOST_DEVICE constexpr bool empty() const { return !kElements; @@ -555,14 +547,15 @@ class Array { return const_reverse_iterator(storage); } - // - // Comparison operators - // - +private: + /// Internal storage + Storage storage[kStorageElements]; }; 
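
Two details of the sub-byte `Array` specialization above are worth calling out: `kStorageElements` now rounds up, so a logical length that does not fill the last storage word still gets a word allocated, and `reference::operator=` inserts an element through a mask-and-shift read-modify-write of that word. The arithmetic is illustrated by this plain-C++ sketch, which is editorial and independent of CUTLASS; the 4-bit width, the 8-bit storage word, and the `set` helper are assumptions for illustration only:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  // Example: 4-bit elements packed into 8-bit storage words.
  constexpr int kBitsPerElement     = 4;
  constexpr int kN                  = 10;                   // logical elements
  constexpr int kElementsPerStorage = 8 / kBitsPerElement;  // 2 per uint8_t
  // Round up, as in the updated kStorageElements computation: 10 -> 5 words.
  constexpr int kStorageElements =
      (kN + kElementsPerStorage - 1) / kElementsPerStorage;

  uint8_t storage[kStorageElements] = {};

  // Masked insert of value `x` at logical index `idx`, mirroring the
  // locate-then-read-modify-write done by operator[] and reference::operator=.
  auto set = [&](int idx, uint8_t x) {
    uint8_t const kMask = (uint8_t(1) << kBitsPerElement) - 1;  // 0x0F
    int word  = idx / kElementsPerStorage;
    int shift = (idx % kElementsPerStorage) * kBitsPerElement;
    storage[word] = uint8_t((storage[word] & ~(kMask << shift)) |
                            ((x & kMask) << shift));
  };

  set(9, 0x7);
  std::printf("storage words: %d, last word: 0x%02x\n",
              kStorageElements, unsigned(storage[kStorageElements - 1]));
  return 0;
}
```

With ten two-per-byte elements the round-up yields five storage bytes, which is exactly what the updated `kStorageElements` expression computes; the old expression would have truncated to four and dropped the last element.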
//////////////////////////////////////////////////////////////////////////////////////////////////// +//////////////////////////////////////////////////////////////////////////////////////////////////// + } // namespace cutlass //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/barrier.h b/include/cutlass/barrier.h new file mode 100644 index 0000000000..6f2373b6df --- /dev/null +++ b/include/cutlass/barrier.h @@ -0,0 +1,377 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Implementation of a CTA-wide barrier for inter-CTA synchronization. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +namespace detail { + +// +// Utilities for abstracting synchronization methods for barriers +// + +struct SyncthreadsSync { + CUTLASS_DEVICE + static void sync() { + __syncthreads(); + } +}; + +struct SyncwarpSync { + CUTLASS_DEVICE + static void sync() { + __syncwarp(); + } +}; + +template < + int ThreadCount, + int BarrierId +> +struct NamedBarrierSync { + CUTLASS_DEVICE + static void sync() { + cutlass::arch::NamedBarrier::sync(ThreadCount, static_cast(BarrierId)); + } +}; + +} // namepspace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Group or CTA-wide semaphore for inter-CTA synchronization. 
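
The `GenericBarrier` template defined just below implements that semaphore as an acquire/release protocol over integer flags in global memory: producer CTAs call `arrive_inc`, and consumer CTAs spin in `wait_eq`/`wait_lt` until the expected count has arrived. The kernel fragment below is a hypothetical editorial sketch of the intended call pattern; the kernel name, workspace layout, and producer/consumer grid split are assumptions, not part of this diff:

```cpp
#include "cutlass/barrier.h"

// Hypothetical: CTAs with blockIdx.x < num_producers each contribute a partial
// result for tile blockIdx.y; one consumer CTA per tile reduces the partials.
// `workspace` holds one int flag per tile and is zero-initialized before launch.
__global__ void reduce_with_flags(int *workspace, int num_producers) {
  using Barrier = cutlass::Barrier;   // GenericBarrier synchronized with __syncthreads()
  int tile_idx   = blockIdx.y;
  int thread_idx = threadIdx.x;

  if (int(blockIdx.x) < num_producers) {
    // ... write this CTA's partial result for the tile ...
    // Release: after a CTA-wide sync, thread 0 bumps the tile's flag.
    Barrier::arrive_inc(workspace, thread_idx, tile_idx);
  }
  else {
    // Acquire: spin until all producers for this tile have signaled.
    Barrier::wait_eq(workspace, thread_idx, tile_idx, num_producers);
    // ... the producers' partials are now visible; reduce them into the output ...
  }
}
```

Note that `arrive_inc` performs the group-wide sync before thread 0 issues the release-ordered increment, which is why callers only pass the flag index and their thread index rather than synchronizing explicitly.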
+template +struct GenericBarrier { + +public: + + /// Flag type + using T = int; + + /// Initial flag value + static const T INIT = 0; + + +protected: + + /// Load flag, as a strong acquire operation (int specialization) + CUTLASS_DEVICE + static int ld_acquire(int *ptr) + { + int state = 0; + +#if (__CUDA_ARCH__ >= 700) + /// SM70 and newer use memory consistency qualifiers + + // Acquire pattern using acquire modifier + asm volatile ("ld.global.acquire.gpu.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr)); + +#else + asm volatile ("ld.cg.global.b32 %0, [%1];\n" : "=r"(state) : "l"(ptr)); +#endif // (__CUDA_ARCH__ >= 700) + + return state; + } + + + /// Reduce into flag, with release pattern (int specialization) + CUTLASS_DEVICE + static void red_release(int *ptr, int val) + { +#if (__CUDA_ARCH__ >= 700) + /// SM70 and newer use memory consistency qualifiers + + // Release pattern using acq_rel fence + relaxed modifier. (The fence also releases data + // that was weakly-written by other threads prior to the last syncthreads) + asm volatile ("fence.acq_rel.gpu;\n"); + asm volatile ("red.relaxed.gpu.global.add.s32 [%0], %1;\n" : : "l"(ptr), "r"(val)); + +#else + __threadfence(); + atomicAdd(ptr, val); +#endif // (__CUDA_ARCH__ >= 700) + } + + +public: + + /// Uses thread[0] to wait for at least the specified count of signals on the given flag counter + CUTLASS_DEVICE + static void wait_lt(void *lock_ptr, int thread_idx, int flag_idx, int count) + { + T *flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; + + if (thread_idx == 0) + { + // Spin-loop + #pragma unroll 1 + while(ld_acquire(flag_ptr) < count) {} + } + + Sync::sync(); + } + + /// Uses thread[0] to wait for at least the specified count of signals on the given flag counter + CUTLASS_DEVICE + static void wait_eq(void *lock_ptr, int thread_idx, int flag_idx, T val = 1) + { + T *flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; + + if (thread_idx == 0) + { + // Spin-loop + #pragma unroll 1 + while(ld_acquire(flag_ptr) != val) {} + } + Sync::sync(); + } + + /// Uses thread[0] to wait for the specified count of signals on the given flag counter + CUTLASS_DEVICE + static void wait_eq_reset(void *lock_ptr, int thread_idx, int flag_idx, T val = 1) { + T *flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; + + if (thread_idx == 0) + { + // Spin-loop + #pragma unroll 1 + while(atomicCAS(flag_ptr, val, 0) != val) {} + } + + Sync::sync(); + } + + /// Increment the arrival count for a flag + CUTLASS_DEVICE + static void arrive_inc(void *lock_ptr, int thread_idx, int flag_idx, int val = 1) + { + T* flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; + + Sync::sync(); + + if (thread_idx == 0) + { + red_release(flag_ptr, val); + } + } + + + /// Increment the arrival counts for a range of flags + CUTLASS_DEVICE + static void arrive_range_inc(void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1, int val = 1) + { + int flag_idx = first_flag_idx + thread_idx; + T* flag_ptr = reinterpret_cast(lock_ptr) + flag_idx; + + // Barrier to make sure all other threads in group have written their data + Sync::sync(); + + // Select threads increment their flags + if (thread_idx < count) { + red_release(flag_ptr, val); + } + } +}; + +using Barrier = GenericBarrier; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/** Structure for managing multiple NamedBarriers to be used by different warp groups, allowing + * runtime index values to be used to call into named barriers with compile-time-constant IDs. 
+ * + * @param ThreadCount_ Number of threads that will wait on a NamedBarrier with a given ID + * @param Offset Value added to the ID passed in by the user to determine the NamedBarrier ID to call into + * @param MaxNumNamedBarriers The maximum number of unique barrier IDs that will be requested on this type +**/ +template < + uint32_t ThreadCount_, + uint32_t Offset = 0, + uint32_t MaxNumNamedBarriers = 16 +> +struct NamedBarrierManager { + + static_assert(MaxNumNamedBarriers <= arch::NamedBarrier::HardwareMaxNumNamedBarriers); + static_assert(MaxNumNamedBarriers + Offset <= arch::NamedBarrier::HardwareMaxNumNamedBarriers, "Barrier IDs cannot exceed 15"); + + // Number of threads participating in the barrier + static constexpr uint32_t ThreadCount = ThreadCount_; + + template + using BarrierSync = cutlass::GenericBarrier>; + + // Underlying type used by all barriers for synchronization. Does not depend on + // template parameter BarrierId, so passing in 0 suffices. + using T = typename BarrierSync<0>::T; + + using IntegerSequence = cute::make_integer_sequence; + + CUTLASS_DEVICE + static + void wait_lt(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, int count) { + wait_lt_helper(idx, lock_ptr, thread_idx, flag_idx, count, IntegerSequence{}); + } + + CUTLASS_DEVICE + static void + wait_eq(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) { + wait_eq_helper(idx, lock_ptr, thread_idx, flag_idx, val, IntegerSequence{}); + } + + CUTLASS_DEVICE + static void + wait_eq_reset(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) { + wait_eq_helper(idx, lock_ptr, thread_idx, flag_idx, val, IntegerSequence{}); + } + + CUTLASS_DEVICE + static void + arrive_inc(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, int val = 1) { + arrive_inc_helper(idx, lock_ptr, thread_idx, flag_idx, val, IntegerSequence{}); + } + + CUTLASS_DEVICE + static void + arrive_range_inc(uint32_t idx, void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1, int val = 1) { + arrive_range_inc_helper(idx, lock_ptr, thread_idx, first_flag_idx, count, val, IntegerSequence{}); + } + +private: + CUTLASS_DEVICE + static void + check_barrier_in_range([[maybe_unused]] uint32_t idx) { + assert((idx < MaxNumNamedBarriers) && "Index exceeds barrier count"); + } + + template + CUTLASS_DEVICE + static void + wait_lt_helper(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, int count, cute::integer_sequence) { + check_barrier_in_range(idx); + ((Idx == idx && (BarrierSync::wait_lt(lock_ptr, thread_idx, flag_idx, count), true)) || ...); + } + + template + CUTLASS_DEVICE + static void + wait_eq_helper(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, T val, cute::integer_sequence) { + check_barrier_in_range(idx); + if constexpr (Reset) { + ((Idx == idx && (BarrierSync::wait_eq_reset(lock_ptr, thread_idx, flag_idx, val), true)) || ...); + } + else { + ((Idx == idx && (BarrierSync::wait_eq(lock_ptr, thread_idx, flag_idx, val), true)) || ...); + } + } + + template + CUTLASS_DEVICE + static void + arrive_inc_helper(uint32_t idx, void *lock_ptr, int thread_idx, int flag_idx, int val, cute::integer_sequence) { + check_barrier_in_range(idx); + ((Idx == idx && (BarrierSync::arrive_inc(lock_ptr, thread_idx, flag_idx, val), true)) || ...); + } + + template + CUTLASS_DEVICE + static void + arrive_range_inc_helper(uint32_t idx, void *lock_ptr, int thread_idx, int first_flag_idx, int count, int val, cute::integer_sequence) { + check_barrier_in_range(idx); + ((Idx == idx 
&& (BarrierSync::arrive_range_inc(lock_ptr, thread_idx, first_flag_idx, count, val), true)) || ...); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/** Structure for synchronizing via contiguous barriers (e.g., __syncwarp, __syncthreads) + * via an API that mirrors that of NamedBarrierManager + * + * @param Synchronizer Synchronization helper exposing a `sync()` method to perform synchronization +**/ +template < + class Synchronizer, + uint32_t ThreadCount_ +> +struct SyncManager { + + // Number of threads participating in the barrier + static constexpr uint32_t ThreadCount = ThreadCount_; + + using BarrierSync = cutlass::GenericBarrier; + + // Underlying type used by all barriers for synchronization. + using T = typename BarrierSync::T; + + CUTLASS_DEVICE + static + void wait_lt(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, int count) { + BarrierSync::wait_lt(lock_ptr, thread_idx, flag_idx, count); + } + + CUTLASS_DEVICE + static void + wait_eq(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) { + BarrierSync::wait_eq(lock_ptr, thread_idx, flag_idx, val); + } + + CUTLASS_DEVICE + static void + wait_eq_reset(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, T val = 1) { + BarrierSync::wait_eq_reset(lock_ptr, thread_idx, flag_idx, val); + } + + CUTLASS_DEVICE + static void + arrive_inc(uint32_t, void *lock_ptr, int thread_idx, int flag_idx, int val = 1) { + BarrierSync::arrive_inc(lock_ptr, thread_idx, flag_idx, val); + } + + CUTLASS_DEVICE + static void + arrive_range_inc(uint32_t idx, void *lock_ptr, int thread_idx, int first_flag_idx, int count = 1, int val = 1) { + BarrierSync::arrive_range_inc(lock_ptr, thread_idx, first_flag_idx, count, val); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/bfloat16.h b/include/cutlass/bfloat16.h index e4d20efc1e..5af6d3ab80 100644 --- a/include/cutlass/bfloat16.h +++ b/include/cutlass/bfloat16.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -33,16 +33,21 @@ \brief Defines a proxy class for storing non-standard 16-bit floating point values with 8 bits of exponent and 7 bit of mantissa. 
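// Usage sketch (illustrative only, not part of the patch above): NamedBarrierManager and
// SyncManager expose the same static interface, so block-scope flag exchange code can be
// written once and instantiated with either backend. The 256-thread count, the BlockSync
// helper, the `flags` workspace and the producer/consumer split are assumptions made for
// this sketch.
#include "cutlass/barrier.h"

// Hypothetical synchronizer satisfying the `sync()` requirement documented for SyncManager.
struct BlockSync {
  CUTLASS_DEVICE static void sync() { __syncthreads(); }
};

template <class BarrierManager>
CUTLASS_DEVICE void exchange_tile(void* flags, int thread_idx, int tile_idx) {
  // Producer side: bump the flag for this tile once its partial results are visible.
  BarrierManager::arrive_inc(/*barrier id*/ 0, flags, thread_idx, /*flag_idx*/ tile_idx);
  // Consumer side: block until the flag reaches 1, then reset it for the next round.
  BarrierManager::wait_eq_reset(/*barrier id*/ 0, flags, thread_idx, /*flag_idx*/ tile_idx, 1);
}

// Either backend satisfies the interface used by exchange_tile:
using NamedBackend = cutlass::NamedBarrierManager<256>;
using SimtBackend  = cutlass::SyncManager<BlockSync, 256>;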
*/ + #pragma once -#if !defined(__CUDACC_RTC__) +#if defined(__CUDACC_RTC__) +#include "cutlass/floating_point_nvrtc.h" +#else #include #include #include #include #endif +#include #include "cutlass/cutlass.h" +#include "cutlass/platform/platform.h" namespace cutlass { @@ -70,9 +75,41 @@ struct alignas(2) bfloat16_t { return h; } +private: + struct from_32_bit_integer_t {}; + static constexpr from_32_bit_integer_t from_32_bit_integer{}; + + template + CUTLASS_HOST_DEVICE + explicit bfloat16_t(from_32_bit_integer_t, T x) { + static_assert(cutlass::platform::is_integral::value && sizeof(T) == 4, "Requires 32-bit integer"); + + float flt = static_cast(x); + uint32_t bits; + + #if defined(__CUDA_ARCH__) + bits = reinterpret_cast(flt); + #else + std::memcpy(&bits, &flt, sizeof(bits)); + #endif + + storage = uint16_t(bits >> 16); + } + +public: /// Default constructor + bfloat16_t() = default; + + /// Reinterpret cast from CUDA's __nv_bfloat16 type CUTLASS_HOST_DEVICE - bfloat16_t() : storage(0) { } + explicit bfloat16_t(__nv_bfloat16 const & x) { + #if defined(__CUDA_ARCH__) + storage = reinterpret_cast(x); + #else + __nv_bfloat16_raw raw(x); + std::memcpy(&storage, &raw.x, sizeof(storage)); + #endif + } /// Floating-point conversion - round toward nearest CUTLASS_HOST_DEVICE @@ -117,18 +154,10 @@ struct alignas(2) bfloat16_t { /// Integer conversion - round toward nearest CUTLASS_HOST_DEVICE - explicit bfloat16_t(int x) { - float flt = static_cast(x); - uint32_t bits; - - #if defined(__CUDA_ARCH__) - bits = reinterpret_cast(flt); - #else - std::memcpy(&bits, &flt, sizeof(bits)); - #endif + explicit bfloat16_t(int x) : bfloat16_t(from_32_bit_integer, x) {} - storage = uint16_t(bits >> 16); - } + CUTLASS_HOST_DEVICE + explicit bfloat16_t(uint32_t x) : bfloat16_t(from_32_bit_integer, x) {} /// Converts to float CUTLASS_HOST_DEVICE @@ -161,6 +190,12 @@ struct alignas(2) bfloat16_t { return (float(*this) != 0.0f); } + /// Bitcasts to CUDA's bf16 type + CUTLASS_DEVICE + __nv_bfloat16 to_nv_bfloat16() const { + return reinterpret_cast<__nv_bfloat16 const &>(storage); + } + /// Obtains raw bits CUTLASS_HOST_DEVICE uint16_t raw() const { @@ -200,7 +235,7 @@ bool signbit(cutlass::bfloat16_t const& h) { CUTLASS_HOST_DEVICE cutlass::bfloat16_t abs(cutlass::bfloat16_t const& h) { - return cutlass::bfloat16_t::bitcast(h.raw() & 0x7fffffff); + return cutlass::bfloat16_t::bitcast(h.raw() & 0x7fff); } CUTLASS_HOST_DEVICE @@ -292,9 +327,9 @@ bfloat16_t copysign(bfloat16_t const& a, bfloat16_t const& b) { // /////////////////////////////////////////////////////////////////////////////////////////////////// +#if !defined(__CUDACC_RTC__) namespace std { -#if !defined(__CUDACC_RTC__) /// Numeric limits template <> struct numeric_limits { @@ -349,9 +384,78 @@ struct numeric_limits { CUTLASS_HOST_DEVICE static cutlass::bfloat16_t denorm_min() { return cutlass::bfloat16_t::bitcast(0x1); } }; -#endif } // namespace std +#endif + +namespace cutlass { +namespace platform { + +/// Forward Declaration +template +struct numeric_limits; + +/// Numeric limits +template <> +struct numeric_limits { + static bool const is_specialized = true; + static bool const is_signed = true; + static bool const is_integer = false; + static bool const is_exact = false; + static bool const has_infinity = true; + static bool const has_quiet_NaN = true; + static bool const has_signaling_NaN = false; +#if !defined(__CUDACC_RTC__) + static std::float_denorm_style const has_denorm = std::denorm_present; +#endif + static bool const has_denorm_loss = 
true; +#if !defined(__CUDACC_RTC__) + static std::float_round_style const round_style = std::round_to_nearest; +#endif + static bool const is_iec559 = false; + static bool const is_bounded = true; + static bool const is_modulo = false; + static int const digits = 7; + + /// Least positive value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t min() { return cutlass::bfloat16_t::bitcast(0x01); } + + /// Minimum finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t lowest() { return cutlass::bfloat16_t::bitcast(0xff7f); } + + /// Maximum finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t max() { return cutlass::bfloat16_t::bitcast(0x7f7f); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t epsilon() { return cutlass::bfloat16_t::bitcast(0x1000); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t round_error() { return cutlass::bfloat16_t(0.5f); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t infinity() { return cutlass::bfloat16_t::bitcast(0x7f80); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t quiet_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t signaling_NaN() { return cutlass::bfloat16_t::bitcast(0x7fff); } + + /// Returns smallest finite value + CUTLASS_HOST_DEVICE + static cutlass::bfloat16_t denorm_min() { return cutlass::bfloat16_t::bitcast(0x1); } +}; + +} // namespace platform +} // namespace cutlass /////////////////////////////////////////////////////////////////////////////////////////////////// // @@ -365,114 +469,190 @@ namespace cutlass { CUTLASS_HOST_DEVICE bool operator==(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return __heq(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()); +#else return float(lhs) == float(rhs); +#endif } CUTLASS_HOST_DEVICE bool operator!=(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return __hne(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()); +#else return float(lhs) != float(rhs); +#endif } CUTLASS_HOST_DEVICE bool operator<(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return __hlt(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()); +#else return float(lhs) < float(rhs); +#endif } CUTLASS_HOST_DEVICE bool operator<=(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return __hle(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()); +#else return float(lhs) <= float(rhs); +#endif } CUTLASS_HOST_DEVICE bool operator>(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return __hgt(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()); +#else return float(lhs) > float(rhs); +#endif } CUTLASS_HOST_DEVICE bool operator>=(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return __hge(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16()); +#else return float(lhs) >= float(rhs); +#endif } CUTLASS_HOST_DEVICE bfloat16_t operator+(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return bfloat16_t(__hadd(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else return bfloat16_t(float(lhs) + float(rhs)); +#endif } CUTLASS_HOST_DEVICE 
bfloat16_t operator-(bfloat16_t const& lhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return bfloat16_t(__hneg(lhs.to_nv_bfloat16())); +#else return bfloat16_t(-float(lhs)); +#endif } CUTLASS_HOST_DEVICE bfloat16_t operator-(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return bfloat16_t(__hsub(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else return bfloat16_t(float(lhs) - float(rhs)); +#endif } CUTLASS_HOST_DEVICE bfloat16_t operator*(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return bfloat16_t(__hmul(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else return bfloat16_t(float(lhs) * float(rhs)); +#endif } CUTLASS_HOST_DEVICE bfloat16_t operator/(bfloat16_t const& lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + return bfloat16_t(__hdiv(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else return bfloat16_t(float(lhs) / float(rhs)); +#endif } CUTLASS_HOST_DEVICE bfloat16_t& operator+=(bfloat16_t & lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hadd(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else lhs = bfloat16_t(float(lhs) + float(rhs)); +#endif return lhs; } CUTLASS_HOST_DEVICE bfloat16_t& operator-=(bfloat16_t & lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hsub(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else lhs = bfloat16_t(float(lhs) - float(rhs)); +#endif return lhs; } CUTLASS_HOST_DEVICE bfloat16_t& operator*=(bfloat16_t & lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hmul(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else lhs = bfloat16_t(float(lhs) * float(rhs)); +#endif return lhs; } CUTLASS_HOST_DEVICE bfloat16_t& operator/=(bfloat16_t & lhs, bfloat16_t const& rhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hdiv(lhs.to_nv_bfloat16(), rhs.to_nv_bfloat16())); +#else lhs = bfloat16_t(float(lhs) / float(rhs)); +#endif return lhs; } CUTLASS_HOST_DEVICE bfloat16_t& operator++(bfloat16_t & lhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hadd(lhs.to_nv_bfloat16(), bfloat16_t(1.0f).to_nv_bfloat16())); +#else float tmp(lhs); ++tmp; lhs = bfloat16_t(tmp); +#endif return lhs; } CUTLASS_HOST_DEVICE bfloat16_t& operator--(bfloat16_t & lhs) { +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hsub(lhs.to_nv_bfloat16(), bfloat16_t(1.0f).to_nv_bfloat16())); +#else float tmp(lhs); --tmp; lhs = bfloat16_t(tmp); +#endif return lhs; } CUTLASS_HOST_DEVICE bfloat16_t operator++(bfloat16_t & lhs, int) { bfloat16_t ret(lhs); +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hadd(lhs.to_nv_bfloat16(), bfloat16_t(1.0f).to_nv_bfloat16())); +#else float tmp(lhs); tmp++; lhs = bfloat16_t(tmp); +#endif return ret; } CUTLASS_HOST_DEVICE bfloat16_t operator--(bfloat16_t & lhs, int) { bfloat16_t ret(lhs); +#if defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 800) + lhs = bfloat16_t(__hsub(lhs.to_nv_bfloat16(), bfloat16_t(1.0f).to_nv_bfloat16())); +#else float tmp(lhs); tmp--; lhs = bfloat16_t(tmp); +#endif return ret; } diff --git a/include/cutlass/blas3.h b/include/cutlass/blas3.h index 3c2df6dd6f..d41f1ee61e 100644 --- a/include/cutlass/blas3.h +++ b/include/cutlass/blas3.h @@ -1,5 +1,5 @@ 
/*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -39,7 +39,9 @@ #include "cutlass/cutlass.h" #include "cutlass/array.h" +#include "cutlass/blas3_types.h" #include "cutlass/coord.h" +#include "cutlass/complex.h" #include "cutlass/functional.h" #include "cutlass/numeric_types.h" @@ -48,41 +50,7 @@ namespace cutlass { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Enumerated type describing the type of kernel (based on input or output matrices). -enum class BlasMode { - kGemm, - kSymmetric, - kHermitian, - kTriangular, - kInvalid -}; - -/// Enumerated type describing the fill mode for matrices for BLAS functions. -enum class FillMode { - kFull, /// The entire tensor is covered. - kLower, /// The 'lower' part of a tensor is covered including diagonal - kUpper, /// The 'upper' part of a tensor is covered including diaognal - kDiagonal, /// Only diagonal elements are covered. - kNone, /// No element is covered. - kInvalid -}; -/// Enumerated type describing the diagonal property of matrices for BLAS functions. -enum class DiagType { - kNonUnit, - kUnit, - kZero, // Only used internally for computing SYMM/HEMM - kInvalid -}; - -/// Enumerated type describing the side dense matrix is in matrix equation for BLAS functions. -enum class SideMode { - kLeft, - kRight, - kInvalid -}; - -///////////////////////////////////////////////////////////////////////////////////////////////// /// Defines FillMode inversions template struct InvertFillMode; @@ -164,7 +132,7 @@ struct MantissaInBits { template <> struct MantissaInBits> { static int constexpr bits = 30; - static double constexpr error = 1.0e-15; + static double constexpr error = 1.0e-14; }; //////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/blas3_types.h b/include/cutlass/blas3_types.h new file mode 100644 index 0000000000..653b93b771 --- /dev/null +++ b/include/cutlass/blas3_types.h @@ -0,0 +1,78 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Enumerated type describing the type of kernel (based on input or output matrices). +enum class BlasMode { + kGemm, + kSymmetric, + kHermitian, + kTriangular, + kInvalid +}; + +/// Enumerated type describing the fill mode for matrices for BLAS functions. +enum class FillMode { + kFull, /// The entire tensor is covered. + kLower, /// The 'lower' part of a tensor is covered including diagonal + kUpper, /// The 'upper' part of a tensor is covered including diaognal + kDiagonal, /// Only diagonal elements are covered. + kNone, /// No element is covered. + kInvalid +}; + +/// Enumerated type describing the diagonal property of matrices for BLAS functions. +enum class DiagType { + kNonUnit, + kUnit, + kZero, // Only used internally for computing SYMM/HEMM + kInvalid +}; + +/// Enumerated type describing the side dense matrix is in matrix equation for BLAS functions. +enum class SideMode { + kLeft, + kRight, + kInvalid +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/block_striped.h b/include/cutlass/block_striped.h new file mode 100644 index 0000000000..09f3fb04fc --- /dev/null +++ b/include/cutlass/block_striped.h @@ -0,0 +1,267 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Utilities for performing block-striped access (load, store, reduce) of trivially-copyable, + statically-sized array types to global memory. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/wmma_array.h" +#include "cutlass/functional.h" +#include "cutlass/complex.h" + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// AccessWidth +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Computes the maximal power-of-two that evenly divides the size of T, capped at Limit +template < + typename T, + int Limit> +struct AccessWidth +{ + // Inductive case + template < + int ObjectBytes, /// Size of T in bytes + int AlignBytes, /// Template induction variable + bool IsAligned = /// Whether ObjectBytes is an even multiple of AlignBytes + ((AlignBytes <= Limit) && (ObjectBytes % AlignBytes == 0))> + struct Detail + { + static const int value = Detail::value; + }; + + // Base case (ObjectBytes is not an even multiple of AlignBytes) + template < + int ObjectBytes, /// Size of T in bytes + int AlignBytes> /// Template induction variable + struct Detail + { + static const int value = AlignBytes / 2; + }; + + /// The maximal power-of-two that evenly divides the size of T + static const int value = Detail< + (int) sizeof(T), + 1>::value; +}; + + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// StripedAccessType +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// ReinterpretCast type for striping a trivially-copyable type in global memory +/// (Default specialization. Striping granularity is type T.) +template < + typename T, /// Data type + int TransferBytes = /// Data access width (16 byte max for global memory access on current architectures) + AccessWidth::value> +struct alignas(TransferBytes) StripedAccessType : public T +{}; + + +/// ReinterpretCast type for striping a trivially-copyable type in global memory +/// (Specialization for cutlass::Array. Striping granularity is a multiple of T.) +template < + typename T, /// Array element type + int N, /// Number of elements in array + bool RegisterSized, /// T is register-sized + int TransferBytes> /// Data access width +struct StripedAccessType< + Array, + TransferBytes> +: public AlignedArray< + T, // Element type of StripedAccessType + __NV_STD_MAX(1, TransferBytes / (int) sizeof(T)), // Number of elements T in StripedAccessType + TransferBytes> // Alignment of StripedAccessType +{}; + + +#if defined(CUTLASS_ARCH_WMMA_ENABLED) + +/// ReinterpretCast type for striping a trivially-copyable type in global memory +/// (Specialization for cutlass::WmmaFragmentArray. Striping granularity is a multiple of T.) 
+template< + typename Use, + int m, + int n, + int k, + typename ElementT, + typename Layout, + int kFragments, + int TransferBytes> +struct StripedAccessType< + WmmaFragmentArray, kFragments>, + TransferBytes> +: public AlignedArray< + ElementT, + __NV_STD_MAX(1, TransferBytes / (int) sizeof(ElementT)), + TransferBytes> +{}; + +#endif // if defined(CUTLASS_ARCH_WMMA_ENABLED) + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// BlockStriped +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Utility for performing block-striped access (load, store) of trivially-copyable, +/// statically-sized array types to global memory +template < + int BlockThreads, + typename ArrayT, + typename AccessT = StripedAccessType > +struct BlockStriped +{ + /// Number of striped accesses + static const int kStripes = int(sizeof(ArrayT) / sizeof(AccessT)); + static_assert(kStripes > 0, "AccessT type must be smaller than or equal to ArrayT type"); + + /// Load + CUTLASS_DEVICE + static void load(ArrayT &data, ArrayT *ptr, int thread_idx) + { + AccessT *access_input = reinterpret_cast(ptr); + AccessT *access_data = reinterpret_cast(&data); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kStripes; ++i) { + access_data[i] = access_input[(BlockThreads * i) + thread_idx]; + } + } + + /// Load & Add + CUTLASS_DEVICE + static void load_add(ArrayT &data, ArrayT *ptr, int thread_idx) + { + AccessT *access_input = reinterpret_cast(ptr); + AccessT *access_data = reinterpret_cast(&data); + + plus add; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kStripes; ++i) + { + access_data[i] = add(access_data[i], access_input[(BlockThreads * i) + thread_idx]); + } + } + + /// Store + CUTLASS_DEVICE + static void store(ArrayT *ptr, const ArrayT &data, int thread_idx) + { + AccessT *access_output = reinterpret_cast(ptr); + const AccessT *access_data = reinterpret_cast(&data); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < kStripes; ++i) { + access_output[(BlockThreads * i) + thread_idx] = access_data[i]; + } + } + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// BlockStripedReduce +///////////////////////////////////////////////////////////////////////////////////////////////// + + +/// Utility for performing block-striped access (load, store, reduce) of trivially-copyable, +/// statically-sized array types to global memory. +/// (Default specialization) +template < + int BlockThreads, + typename ArrayT, + typename ElementT = typename StripedAccessType::Element> +struct BlockStripedReduce : + BlockStriped< + BlockThreads, + ArrayT, + ElementT> +{ + /// Reduce + CUTLASS_DEVICE + static void reduce(ArrayT *ptr, const ArrayT &data, int thread_idx) + { + cutlass::atomic_add reduce; + ElementT *access_output = reinterpret_cast(ptr); + const ElementT *access_data = reinterpret_cast(&data); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < BlockStripedReduce::kStripes; ++i) { + reduce(access_output + (BlockThreads * i) + thread_idx, access_data[i]); + } + } +}; + + +/// Utility for performing block-striped access (load, store, reduce) of trivially-copyable, +/// statically-sized array types to global memory. +/// (Specialization for half_t. Uses half2 vectorized-reduction.) 
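// Usage sketch (illustrative only, not part of the patch above): one thread block spills each
// thread's accumulator fragment to a global workspace and another block folds it back into
// registers. kBlockThreads, AccumulatorArray, `workspace` and the producer/consumer flag are
// assumptions made for this sketch; the workspace must hold one ArrayT per participating
// thread, striped at StripedAccessType granularity.
#include "cutlass/array.h"
#include "cutlass/block_striped.h"

constexpr int kBlockThreads = 128;
using AccumulatorArray = cutlass::Array<float, 8>;   // one fragment per thread

CUTLASS_DEVICE
void share_partials(AccumulatorArray& frag, AccumulatorArray* workspace, int thread_idx, bool is_producer) {
  using Striped = cutlass::BlockStriped<kBlockThreads, AccumulatorArray>;
  if (is_producer) {
    // Producer block: coalesced, block-striped store of each thread's fragment.
    Striped::store(workspace, frag, thread_idx);
  }
  else {
    // Consumer block: read the stored fragments back, accumulating into registers.
    // (Some inter-block synchronization must order the producer and consumer.)
    Striped::load_add(frag, workspace, thread_idx);
  }
}

// Alternatively, BlockStripedReduce<kBlockThreads, AccumulatorArray>::reduce(workspace, frag, thread_idx)
// accumulates into the workspace atomically, avoiding the explicit producer/consumer ordering.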
+template < + int BlockThreads, + typename ArrayT> +struct BlockStripedReduce : + BlockStriped< + BlockThreads, + ArrayT, + half2> +{ + static_assert(BlockStripedReduce::kStripes % 2 == 0, "Array of half must be even number in length"); + + /// Reduce + CUTLASS_DEVICE + static void reduce(ArrayT *ptr, const ArrayT &data, int thread_idx) + { + cutlass::atomic_add reduce; + half2 *access_output = reinterpret_cast(ptr); + const half2 *access_data = reinterpret_cast(&data); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < BlockStripedReduce::kStripes; ++i) + { + reduce(access_output + (BlockThreads * i) + thread_idx, access_data[i]); + } + } +}; + + +} // namespace cutlass + diff --git a/include/cutlass/cluster_launch.hpp b/include/cutlass/cluster_launch.hpp new file mode 100644 index 0000000000..a0fa22b6bb --- /dev/null +++ b/include/cutlass/cluster_launch.hpp @@ -0,0 +1,275 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief PTX for TMA Tensor Memory Access operators on memory added for SM90 +*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/trace.h" +#if defined(__CUDACC_RTC__) +#include +#else +#include +#include +#endif + +#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))) +# define CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED +#endif + +namespace cutlass { + +#ifndef NDEBUG +#define Return_Status(cudaError_t_status) \ + if (cudaError_t_status != cudaSuccess) { \ + fprintf(stderr, \ + "[ ERROR: CUDA Runtime ] %s:%d: %s\n", \ + __FILE__, \ + __LINE__, \ + cudaGetErrorString(cudaError_t_status)); \ + return Status::kInvalid; \ + } else { \ + return Status::kSuccess; \ + } +#else +#define Return_Status(cudaError_t_status) \ + if (cudaError_t_status != cudaSuccess) { \ + return Status::kInvalid; \ + } else { \ + return Status::kSuccess; \ + } +#endif + +struct ClusterLauncher { + constexpr static int MaxClusterSize = 32; + + // Check for hardware compatibility + static inline CUTLASS_HOST + Status check_cluster_dims(dim3 grid, dim3 cluster) { + if (((cluster.x * cluster.y * cluster.z) <= MaxClusterSize) && + (grid.x % cluster.x == 0) && (grid.y % cluster.y == 0) && (grid.z % cluster.z == 0)) { + return Status::kSuccess; + } + else { + CUTLASS_TRACE_HOST("ClusterLauncher: Invalid cluster configuration -- aborting launch."); + return Status::kInvalid; + } + } + + static inline CUTLASS_HOST + Status +#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) + init(void const* kernel_function) +#else + init(void const* /* kernel_function */) +#endif + { +#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + if (kernel_function == nullptr) { + CUTLASS_TRACE_HOST("kernel_function is null"); + return Status::kInvalid; + } + CUTLASS_TRACE_HOST("Checking previous error state before calling cudaFuncSetAttribute"); + cudaError_t prevStatus = cudaGetLastError(); + if (prevStatus != cudaSuccess) { + fprintf(stderr, + "[ ERROR: CUDA Runtime ] %s:%d: %s\n", + __FILE__, + __LINE__, + cudaGetErrorString(prevStatus)); + return Status::kInvalid; + } + CUTLASS_TRACE_HOST("Calling cudaFuncSetAttribute"); +#endif + // This attribute was added in CUDA 11.8. + cudaError_t status = + cudaFuncSetAttribute( + kernel_function, cudaFuncAttributeNonPortableClusterSizeAllowed, 1); + Return_Status(status); +#else + return Status::kInvalid; +#endif + } + + // This is the method we expect to use going forward + static inline CUTLASS_HOST + Status launch( + dim3 const grid_dims, + dim3 const cluster_dims, + dim3 const block_dims, + size_t const smem_size, + cudaStream_t cuda_stream, + void const* kernel, + void** kernel_params, + bool launch_with_pdl = false) { +#if defined(CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED) + if (check_cluster_dims(grid_dims, cluster_dims) != Status::kSuccess) { + CUTLASS_TRACE_HOST("ClusterLauncher: check_cluster_dims() failed. Aborting."); + return Status::kInvalid; + } + + auto init_status = init(kernel); + if (init_status != Status::kSuccess) { + CUTLASS_TRACE_HOST("ClusterLauncher: init(kernel) failed with status " << int(init_status) << ". 
Aborting."); + return Status::kInvalid; + } + + cudaLaunchConfig_t launch_config; + launch_config.gridDim = {grid_dims.x, grid_dims.y, grid_dims.z}; + launch_config.blockDim = {block_dims.x, block_dims.y, block_dims.z}; + launch_config.dynamicSmemBytes = smem_size; + launch_config.stream = cuda_stream; + + cudaLaunchAttribute launch_attribute[2]; + + launch_attribute[0].id = cudaLaunchAttributeClusterDimension; + launch_attribute[0].val.clusterDim.x = cluster_dims.x; + launch_attribute[0].val.clusterDim.y = cluster_dims.y; + launch_attribute[0].val.clusterDim.z = cluster_dims.z; + + launch_attribute[1].id = cudaLaunchAttributeProgrammaticStreamSerialization; + launch_attribute[1].val.programmaticStreamSerializationAllowed = 1; + + launch_config.numAttrs = launch_with_pdl ? 2 : 1; + + launch_config.attrs = launch_attribute; + + CUTLASS_TRACE_HOST("ClusterLauncher: Launching GPC_CLUSTER_GRID GridDims = " + "(" << grid_dims.x << ", " << grid_dims.y << ", " << grid_dims.z << "), " + "And ClusterDims = " + "(" << cluster_dims.x << ", " << cluster_dims.y << ", " << cluster_dims.z << ")\n"); + + cutlass::arch::synclog_setup(); + cudaError_t status = cudaLaunchKernelExC(&launch_config, kernel, kernel_params); + Return_Status(status); +#else + CUTLASS_TRACE_HOST("ClusterLauncher: CUTLASS_SM90_CLUSTER_LAUNCH_ENABLED not defined! Aborting cluster launch."); + return Status::kInvalid; +#endif + } + +}; + +namespace detail { + +template +void* checked_addressof(Arg&& arg) { + static_assert(! std::is_rvalue_reference_v || ! std::is_const_v, "You cannot take the address of a const rvalue reference (const T&&)."); + // We use std::addressof to ensure we get the address, + // in case the type has an overloaded operator&. + // Note that this precludes `const T&&` references. + return const_cast(reinterpret_cast(std::addressof(arg))); +} + +} // namespace detail + +//! Parameters for launch_on_cluster (see below). +struct ClusterLaunchParams { + //! Grid dimensions + dim3 grid_dims{1, 1, 1}; + + //! Block dimensions + dim3 block_dims{1, 1, 1}; + + //! Cluster dimensions + dim3 cluster_dims{1, 1, 1}; + + //! Number of bytes required for the kernel's shared memory. + int smem_size_in_bytes = 0; + + //! CUDA stream on which to launch the kernel. + cudaStream_t cuda_stream = nullptr; +}; + +/// @brief Launch the kernel on the stream using cluster launch. +/// +/// @param params Cluster launch parameters (see above). +/// @param kernel_ptr Pointer to the kernel function (see example). +/// @param args Zero or more arguments to pass to the kernel. +/// +/// @tparam Args Types of the arguments passed to the kernel. +/// Don't specify this/these template argument(s) explicitly. +/// +/// @return Status::Success on success, else an error code. +/// +/// @code +/// template +/// __global__ void kernel(A a, B b, C c); +/// +/// X x = get_x(); +/// Y y = get_y(); +/// Z z = get_z(); +/// +/// void const* kernel_ptr = +/// const_cast(reinterpret_cast( +/// &kernel)); +/// auto status = launch_kernel_on_cluster( +/// {grid_dims, block_dims, cluster_dims, sizeof(SharedMemory)}, +/// kernel_ptr, x, y, z); +/// @endcode +template +CUTLASS_HOST cutlass::Status +launch_kernel_on_cluster(const ClusterLaunchParams& params, + void const* kernel_ptr, + Args&& ... args) +{ + // Unfortunately, we find ourselves needing to pass in + // the parameters as an array of raw pointers. 
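// Usage sketch (illustrative only, not part of the patch above): launching a kernel with a
// 2x1x1 cluster through ClusterLauncher directly, opting in to Programmatic Dependent Launch
// via the trailing flag. The kernel, its arguments and the launch shape are assumptions made
// for this sketch.
__global__ void my_kernel(float* out, int n);   // hypothetical kernel

cutlass::Status launch_with_pdl_example(float* out, int n, size_t smem_size, cudaStream_t stream) {
  dim3 grid(16, 1, 1);
  dim3 cluster(2, 1, 1);
  dim3 block(256, 1, 1);
  void* params[] = { &out, &n };                // addresses of the kernel arguments
  return cutlass::ClusterLauncher::launch(
      grid, cluster, block, smem_size, stream,
      reinterpret_cast<void const*>(&my_kernel), params,
      /*launch_with_pdl*/ true);
}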
+ if constexpr (sizeof...(Args) == 0) { + return cutlass::ClusterLauncher::launch( + params.grid_dims, + params.cluster_dims, + params.block_dims, + params.smem_size_in_bytes, + params.cuda_stream, + kernel_ptr, nullptr); + } + else { + void* kernel_params[sizeof...(Args)] = { + detail::checked_addressof(std::forward(args))... + }; + return cutlass::ClusterLauncher::launch( + params.grid_dims, + params.cluster_dims, + params.block_dims, + params.smem_size_in_bytes, + params.cuda_stream, + kernel_ptr, + kernel_params); + } +} + +} // namespace cutlass diff --git a/include/cutlass/complex.h b/include/cutlass/complex.h index aeccd2db1d..6d0bf31df6 100644 --- a/include/cutlass/complex.h +++ b/include/cutlass/complex.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -28,10 +28,13 @@ * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. * **************************************************************************************************/ + #pragma once #include +#include + #if defined(__CUDACC_RTC__) #include #else @@ -39,11 +42,11 @@ #endif #include "cutlass/cutlass.h" -#include "cutlass/half.h" +#include "cutlass/functional.h" +#include "cutlass/platform/platform.h" #include "cutlass/real.h" -#include "cutlass/bfloat16.h" -#include "cutlass/tfloat32.h" +#include "cutlass/numeric_types.h" #include "cutlass/fast_math.h" @@ -53,8 +56,7 @@ namespace cutlass { -////////////////////////////////////////////////////////////////////////////////////////////////// - +///////////////////////////////////////////////////////////////////////////////////////////////// /// Enumeraed type describing a transformation on a complex value. 
enum class ComplexTransform { kNone, @@ -116,6 +118,18 @@ double const &imag(cuDoubleComplex const &z) { return z.y; } /// Returns the imaginary part of the complex number CUTLASS_HOST_DEVICE double &imag(cuDoubleComplex &z) { return z.y; } + +// Returns the conjugate of the complex number +CUTLASS_HOST_DEVICE cuFloatComplex +conj(cuFloatComplex const& z) { + return make_cuFloatComplex(z.x, -z.y); +} + +// Returns the conjugate of the complex number +CUTLASS_HOST_DEVICE cuDoubleComplex +conj(cuDoubleComplex const& z) { + return make_cuDoubleComplex(z.x, -z.y); +} #endif /////////////////////////////////////////////////////////////////////////////////////////////////// @@ -128,6 +142,7 @@ class complex { public: /// Type alias for scalar type + using value_type = T; private: // @@ -146,15 +161,18 @@ class complex // Methods // -/// Constructor + /// Default constructor + complex() = default; + + /// Constructor CUTLASS_HOST_DEVICE - complex(T r = T(0)) : _real(r), _imag(T(0)) {} + complex(T r) : _real(r), _imag(T(0)) {} -/// Constructor + /// Constructor CUTLASS_HOST_DEVICE complex(T r, T i) : _real(r), _imag(i) {} - // -/// Constructor + + /// Constructor template CUTLASS_HOST_DEVICE complex(complex const &z) : _real(static_cast(z.real())), _imag(static_cast(z.imag())) {} @@ -170,16 +188,6 @@ class complex complex(cuDoubleComplex const &z) : _real(static_cast(cuCreal(z))), _imag(static_cast(cuCimag(z))) {} #endif - /// Assignment - template - CUTLASS_HOST_DEVICE - complex& operator=(complex const &z) - { - _real = static_cast(z.real()); - _imag = static_cast(z.imag()); - return *this; - } - /// Equality operator CUTLASS_HOST_DEVICE bool operator==(complex const &rhs) const { return this->real() == rhs.real() && this->imag() == rhs.imag(); @@ -196,6 +204,24 @@ class complex return complex(this->real() + rhs.real(), this->imag() + rhs.imag()); } + /// Reduction into memory address. Components may update out of order. + template + CUTLASS_DEVICE void red(complex *ptr) const { + static_assert(platform::is_same::value, "Component type must match"); + cutlass::atomic_add reduce; + reduce(&ptr->_real, _real); + reduce(&ptr->_imag, _imag); + } + + /// Reduction into memory address. Components may update out of order. (Half specialization) + CUTLASS_DEVICE void red(complex *ptr) const { + static_assert(platform::is_same::value, "Component type must match"); + half2 *h2_ptr = reinterpret_cast(ptr); + half2 h2_data = reinterpret_cast(*this); + cutlass::atomic_add reduce; + reduce(h2_ptr, h2_data); + } + /// Subtraction template CUTLASS_HOST_DEVICE complex operator-(complex const &rhs) const { @@ -283,6 +309,13 @@ class complex CUTLASS_HOST_DEVICE T &imag() { return _imag; } + /// Set the real part of the complex number + CUTLASS_HOST_DEVICE + void real(T real) { _real = real; } + + /// Set the imaginary part of the complex number + CUTLASS_HOST_DEVICE + void imag(T imag) { _imag = imag; } #if !defined(__CUDACC_RTC__) /// Converts to cuFloatComplex @@ -295,60 +328,90 @@ class complex #endif }; +// Complex conjugate +template +CUTLASS_HOST_DEVICE complex conj(complex const& z) { + return {z.real(), -z.imag()}; +} + /////////////////////////////////////////////////////////////////////////////////////////////////// // // Accessors for complex template // -/// Returns the real part of the complex number -template -CUTLASS_HOST_DEVICE T const &real(complex const &z) { - return z.real(); -} +// Nonmember real and imag need to work for non-complex numbers too. 
+// That means cutlass::complex, std::complex, cuda::std::complex, and +// any user-defined complex number type that looks like std::complex. +// It's reasonable to assume that a "complex number type" has +// zero-argument real() and imag() member functions returning +// non-void. While cuFloatComplex and cuDoubleComplex lack those +// member functions, one-argument nonmember real and imag overloads +// for those types are defined above. -/// Returns the real part of the complex number -template -CUTLASS_HOST_DEVICE T &real(complex &z) { - return z.real(); -} +namespace detail { -/// Returns the imaginary part of the complex number -template -CUTLASS_HOST_DEVICE T const &imag(complex const &z) { - return z.imag(); -} +template +struct has_zero_argument_real_member_function : + cutlass::platform::false_type +{}; -/// Returns the imaginary part of the complex number template -CUTLASS_HOST_DEVICE T &imag(complex &z) { - return z.imag(); -} +struct has_zero_argument_real_member_function().real()) + > + > +> : cutlass::platform::true_type +{}; -/// Returns the real part of the real number template -CUTLASS_HOST_DEVICE T const &real(T const &r) { - return r; -} +constexpr bool has_zero_argument_real_member_function_v = + has_zero_argument_real_member_function::value; -/// Returns the real part of the real number -template -CUTLASS_HOST_DEVICE T &real(T &r) { - return r; -} +template +struct has_zero_argument_imag_member_function : + cutlass::platform::false_type +{}; -/// Returns the imaginary part of the real number template -CUTLASS_HOST_DEVICE T const &imag(T const &r) { - return T(); -} +struct has_zero_argument_imag_member_function().imag()) + > + > +> : cutlass::platform::true_type +{}; -/// Returns the imaginary part of the complex number template -CUTLASS_HOST_DEVICE T &imag(T &r) { - return T(); -} +constexpr bool has_zero_argument_imag_member_function_v = + has_zero_argument_imag_member_function::value; + +} // namespace detail +template +CUTLASS_HOST_DEVICE auto real(T z) { + if constexpr (detail::has_zero_argument_real_member_function_v) { + return z.real(); + } else { + return z; + } +} + +template +CUTLASS_HOST_DEVICE auto imag(T z) { + if constexpr (detail::has_zero_argument_imag_member_function_v) { + return z.imag(); + } else { + // Imaginary part of a non-complex input has the same type as the + // input, and its value is zero. CUTLASS assumes in this case + // that value-initializing T is well-formed and results in zero. + return T{}; + } +} + // // Output operators // @@ -375,10 +438,36 @@ std::ostream &operator<<(std::ostream &out, complex const &z) { // Non-member functions defined for complex numbers // -/// Returns the magnitude of the complex number +// abs returns the magnitude of the complex number. + +CUTLASS_HOST_DEVICE float abs(complex const &z) { + return ::hypot(z.real(), z.imag()); +} + +CUTLASS_HOST_DEVICE double abs(complex const &z) { + return ::hypot(z.real(), z.imag()); +} + +// In theory, it would make sense to add a complex +// specialization of abs here, since hypot works for long double too. +// In practice, long double doesn't have a portable number of bits or +// behavior, so users who care about higher-precision floating-point +// computation should probably insist on an actual FP128 type. + template CUTLASS_HOST_DEVICE T abs(complex const &z) { - return sqrt(norm(z)); + // cutlass::complex permits all kinds of T, including types that + // don't have NaN. 
For a generic floating-point type with Inf + // and/or NaN, LAPACK's DLAPY2 algorithm would make sense, as it + // would handle issues like avoiding unwarranted overflow if + // z.real() or z.imag() is slightly bigger than the square root of + // the max finite number. That could be a future improvement; for + // now, the code just uses the naive algorithm. + // + // Use the "swap two-step" idiom so that argument-dependent lookup + // can find any CUTLASS-specific overloads. + using cutlass::sqrt; + return sqrt(z.real() * z.real() + z.imag() * z.imag()); } /// Returns the magnitude of the complex number @@ -414,33 +503,70 @@ CUTLASS_HOST_DEVICE R norm_accumulate(T const &x, R const & accumulator) { /// Norm accumulate specialized for complex types template CUTLASS_HOST_DEVICE R norm_accumulate(complex const &z, R const &accumulator) { - return accumulator + static_cast(real(z)) * static_cast(real(z)) + + return accumulator + static_cast(real(z)) * static_cast(real(z)) + static_cast(imag(z)) * static_cast(imag(z)); } -/// Returns the complex conjugate -CUTLASS_HOST_DEVICE float conj(float const &z) { - return z; +namespace detail { + +template +CUTLASS_HOST_DEVICE T conj_impl(T const& z, cutlass::platform::true_type) { + return conj(z); } -/// Returns the complex conjugate -CUTLASS_HOST_DEVICE double conj(double const &z) { +template +CUTLASS_HOST_DEVICE T conj_impl(T const& z, cutlass::platform::false_type) { return z; } -/// Returns the complex conjugate -template -CUTLASS_HOST_DEVICE complex conj(complex const &z) { - return complex(real(z), -imag(z)); +template +CUTLASS_HOST_DEVICE T conj_impl(T const& z) { + constexpr bool use_unqualified_conj = + ! cutlass::platform::is_arithmetic_v && + ! detail::has_cutlass_conj_v && + detail::has_unqualified_conj_v; + return conj_impl(z, cutlass::platform::bool_constant{}); } -/// Indentity transform for non-complex types -template -CUTLASS_HOST_DEVICE T conj(T const &z) { - static_assert( !platform::is_same::value && - !platform::is_same::value && - !platform::is_same>::value && - !platform::is_same>::value, "May not be a complex data type"); - return z; + +} // namespace detail + +// Return the complex conjugate of the input. +// +// This MUST be a function and not a function object, because it may +// be common practice for downstream types to define specifically +// cutlass::conj overloads, instead of overloads in their namespace. +// +// As a result of this being a function and not a function object, +// CUTLASS code needs to declare "using cutlass::conj;" in scope and +// then call this function unqualified, just like std::swap. +// +// If an overload already exists for cutlass::conj(T), that overload +// will be called instead of this one. Otherwise: +// +// 1. for arithmetic types, return z; +// +// 2. for types where (namespace-unqualified) conj(z) is well formed +// and cutlass::conj(z) is NOT well formed, return conj(z); and, +// +// 3. for everything else, return z. +// +// Regarding (1), the C++ Standard Library makes std::conj always +// return std::complex, even for (noncomplex) arithmetic types. +// cutlass::conj(T t) needs to return type T. This follows the +// convention of linear algebra software like the BLAS, where +// "conjugate transpose" means the same thing as "transpose" for a +// matrix of noncomplex numbers. 
+// +// Case (2) covers std::complex, cuda::std::complex, and non-Standard +// (including user-defined) complex number types (for which "conj(z)" +// is findable via argument-dependent lookup, but does not live in the +// cutlass namespace). It excludes cutlass::conj(z) in order to +// prevent infinite recursion. +// +// Case (3) covers non-Standard non-complex number types. +template +CUTLASS_HOST_DEVICE T conj(T const& z) { + return detail::conj_impl(z); } /// Projects the complex number z onto the Riemann sphere @@ -494,24 +620,24 @@ CUTLASS_HOST_DEVICE complex sin(complex const &z) { return (exp(-z) - exp(z)) * complex(T(0), T(1) / T(2)); } -/// Comparison +/// Comparison template CUTLASS_HOST_DEVICE bool operator<(complex const &lhs, complex const &rhs) { - //TODO - return true; + return true; } ////////////////////////////////////////////////////////////////////////////////////////////////// /// Partial specialization for complex-valued type. template -struct RealType< complex > { +struct RealType< complex > +{ using Type = T; /// Number of elements static int const kExtent = 2; -CUTLASS_HOST_DEVICE + CUTLASS_HOST_DEVICE static complex from_real(double x) { return complex(static_cast(x)); } @@ -549,6 +675,147 @@ struct is_complex> { static bool const value = true; }; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// functional.h numeric specializations +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Squares with optional conversion +template +struct magnitude_squared, Output> { + CUTLASS_HOST_DEVICE + Output operator()(complex lhs) const { + multiplies mul_op; + + Output y_r = Output(lhs.real()); + Output y_i = Output(lhs.imag()); + + return mul_op(y_r, y_r) + mul_op(y_i, y_i); + } +}; + +/// Fused multiply-add +template +struct multiply_add, complex, complex> { + CUTLASS_HOST_DEVICE + complex operator()( + complex const &a, + complex const &b, + complex const &c) const { + + T real = c.real(); + T imag = c.imag(); + + real += a.real() * b.real(); + real += -a.imag() * b.imag(); + imag += a.real() * b.imag(); + imag += a.imag () * b.real(); + + return complex{ + real, + imag + }; + } +}; + +/// Fused multiply-add +template +struct multiply_add, T, complex> { + CUTLASS_HOST_DEVICE + complex operator()( + complex const &a, + T const &b, + complex const &c) const { + + T real = c.real(); + T imag = c.imag(); + + real += a.real() * b; + imag += a.imag () * b; + + return complex{ + real, + imag + }; + } +}; + +/// Fused multiply-add +template +struct multiply_add, complex> { + CUTLASS_HOST_DEVICE + complex operator()( + T const &a, + complex const &b, + complex const &c) const { + + T real = c.real(); + T imag = c.imag(); + + real += a * b.real(); + imag += a * b.imag(); + + return complex{ + real, + imag + }; + } +}; + +/// Conjugate +template +struct conjugate> { + CUTLASS_HOST_DEVICE + complex operator()(complex const &a) const { + // Invoke the complex overload specifically, rather than + // wasting the compiler's effort on overload resolution. + return cutlass::conj(a); + } +}; + +#if ! 
defined(__CUDACC_RTC__) +template <> +struct conjugate { + CUTLASS_HOST_DEVICE + cuFloatComplex operator()(cuFloatComplex const& z) const { + return make_cuFloatComplex(z.x, -z.y); + } +}; + +template <> +struct conjugate { + CUTLASS_HOST_DEVICE + cuDoubleComplex operator()(cuDoubleComplex const& z) const { + return make_cuDoubleComplex(z.x, -z.y); + } +}; +#endif + +/// Computes the square of a difference with optional conversion +template +struct magnitude_squared_difference, Output> { + CUTLASS_HOST_DEVICE + Output operator()(complex lhs, complex rhs) const { + multiplies mul_op; + + Output y_r = Output(lhs.real()) - Output(rhs.real()); + Output y_i = Output(lhs.imag()) - Output(rhs.imag()); + + return mul_op(y_r, y_r) + mul_op(y_i, y_i); + } +}; + +/// Reduces value into the data pointed to by ptr (complex specialization) +template +struct atomic_add> { + CUTLASS_DEVICE + void operator()(complex *ptr, const complex &data) + { + data.red(ptr); + } +}; + + ////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace cutlass diff --git a/include/cutlass/constants.h b/include/cutlass/constants.h index abb7cab438..49d96045aa 100644 --- a/include/cutlass/constants.h +++ b/include/cutlass/constants.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/include/cutlass/conv/collective/builders/sm90_common.inl b/include/cutlass/conv/collective/builders/sm90_common.inl new file mode 100644 index 0000000000..526db83edf --- /dev/null +++ b/include/cutlass/conv/collective/builders/sm90_common.inl @@ -0,0 +1,96 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/layout/tensor.h" +#include "cutlass/arch/mma.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/dispatch_policy.hpp" +#include "cutlass/detail/layout.hpp" +#include "cutlass/gemm/collective/builders/sm90_common.inl" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::collective::detail { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Maps a rank-1 cute::Shape<> representing the cluster shape on to the IM2COL TMA atom that should be used with it +template +constexpr auto +sm90_cluster_shape_to_im2col_tma_atom(UnimodalClusterShape unimodal_cluster_shape) { + static_assert(cute::rank(unimodal_cluster_shape) == 1, + "Use this function to figure out TMA for each mode individually."); + + if constexpr (cute::size(unimodal_cluster_shape) == 1) { + return cute::SM90_TMA_LOAD_IM2COL{}; + } + else { + return cute::SM90_TMA_LOAD_IM2COL_MULTICAST{}; + } +} + +// Collective tile traits struct that serves as a type list containing a tensor's mem layouts and atoms for the +template< + class GmemTiledCopy_, + class SmemLayout_, + class SmemCopyAtom_ = void +> +struct Sm90ImplicitGemmTileTraits { + using GmemTiledCopy = GmemTiledCopy_; + using SmemLayout = SmemLayout_; + using SmemCopyAtom = SmemCopyAtom_; +}; + +// Accepts a cutlass::layout::Tensor tag and computes the corresponding spatial dimension count +template +constexpr int +gmem_layout_tags_to_spatial_dims() { + static_assert(cute::is_same_v); + if constexpr (cute::is_same_v) { + return 1; + } + else if constexpr (cute::is_same_v) { + return 2; + } + else if constexpr (cute::is_same_v) { + return 3; + } + else { + static_assert(cutlass::detail::dependent_false); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::collective::detail + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/collective/builders/sm90_gmma_builder.inl b/include/cutlass/conv/collective/builders/sm90_gmma_builder.inl new file mode 100644 index 0000000000..a08209efb6 --- /dev/null +++ b/include/cutlass/conv/collective/builders/sm90_gmma_builder.inl @@ -0,0 +1,257 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/conv/collective/builders/sm90_common.inl" + +// SM90 Collective Builders should be used only starting CUDA 12.0 +#if (__CUDACC_VER_MAJOR__ >= 12) +#define CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::collective { +using namespace cute; + +namespace detail { + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr int +compute_stage_count_or_override(StageCount stage_count) { + return stages; +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. +template +constexpr int +compute_stage_count_or_override(cute::Int stage_count) { + return stages; +} + +// Returns the maximum number of smem tiles that can be used with a given smem capacity, or overrides with manual count. 
+template +constexpr int +compute_stage_count_or_override(StageCountAutoCarveout stage_count) { + constexpr auto mainloop_pipeline_bytes = sizeof(typename cutlass::PipelineTmaAsync<1>::SharedStorage); + constexpr auto a_bits = cute::sizeof_bits_v; + constexpr auto b_bits = cute::sizeof_bits_v; + constexpr int stage_bytes = + cutlass::bits_to_bytes(a_bits * size<0>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + cutlass::bits_to_bytes(b_bits * size<1>(TileShapeMNK{}) * size<2>(TileShapeMNK{})) + + static_cast(mainloop_pipeline_bytes); + + return (CapacityBytes - carveout_bytes) / stage_bytes; +} + +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA_TMA_WS_SS_FPROP +template < + conv::Operator ConvOp, + class ElementA, + class GmemLayoutA, + int AlignmentA, + class ElementB, + class GmemLayoutB, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ConvOp, + ElementA, + GmemLayoutA, + AlignmentA, + ElementB, + GmemLayoutB, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t || + cute::is_same_v || + cute::is_same_v> +> { + static_assert(is_static::value); + static_assert(is_static::value); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + static_assert(cutlass::gemm::collective::detail::is_aligned(), + "Should meet TMA alignment requirement\n"); + + // For fp32 types, map to tf32 MMA value type + using ElementAMma = cute::conditional_t, tfloat32_t, ElementA>; + using ElementBMma = cute::conditional_t, tfloat32_t, ElementB>; + + // For fprop, majorA = K, major B = K; + // For wgrad, majorA = MN, major B = MN; + // For dgrad, majorA = K, major B = MN; + static constexpr cute::GMMA::Major GmmaMajorA = + (ConvOp == conv::Operator::kWgrad) ? cute::GMMA::Major::MN : cute::GMMA::Major::K; + static constexpr cute::GMMA::Major GmmaMajorB = + (ConvOp == conv::Operator::kFprop) ? cute::GMMA::Major::K : cute::GMMA::Major::MN; + + using AtomLayoutMNK = cute::conditional_t, + Layout>, Layout>>; + + using TiledMma = decltype(cute::make_tiled_mma(cute::GMMA::ss_op_selector< + ElementAMma, ElementBMma, ElementAccumulator, TileShape_MNK, GmmaMajorA, GmmaMajorB>(), AtomLayoutMNK{})); + + // For wgrad kernel, tensor A uses tma tiled mode and tensor B uses tma im2col mode. 
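As an aside, the carveout-based stage computation above reduces to simple integer arithmetic: one stage costs a smem tile of A, a smem tile of B, and the pipeline barrier storage, and the builder fits as many stages as remain after the epilogue carveout. A minimal standalone sketch of that arithmetic follows; `stages_that_fit`, the 228 KB capacity, the 16-byte per-stage barrier cost, and the fp16 128x128x64 tile are illustrative assumptions, not constants taken from this header.

```cpp
#include <cstdio>

// Mirrors the carveout-based stage computation with plain integers.
// All concrete numbers used below are illustrative assumptions.
constexpr int stages_that_fit(int capacity_bytes, int carveout_bytes,
                              int tile_m, int tile_n, int tile_k,
                              int bytes_a, int bytes_b, int pipeline_bytes_per_stage) {
  int stage_bytes = tile_m * tile_k * bytes_a     // one smem tile of A
                  + tile_n * tile_k * bytes_b     // one smem tile of B
                  + pipeline_bytes_per_stage;     // barrier storage for the stage
  return (capacity_bytes - carveout_bytes) / stage_bytes;
}

int main() {
  // e.g. fp16 A/B, a 128x128x64 CTA tile, 32 KB reserved for the epilogue
  std::printf("%d stages\n",
              stages_that_fit(228 * 1024, 32 * 1024, 128, 128, 64, 2, 2, 16));
  return 0;
}
```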
+ using GmemTiledCopyA = cute::conditional_t(ClusterShape_MNK{}))), + decltype(cutlass::conv::collective::detail::sm90_cluster_shape_to_im2col_tma_atom(cute::shape<1>(ClusterShape_MNK{})))>; + using GmemTiledCopyB = cute::conditional_t(ClusterShape_MNK{}))), + decltype(cutlass::gemm::collective::detail::sm90_cluster_shape_to_tma_atom(cute::shape<0>(ClusterShape_MNK{})))>; + + using SmemLayoutAtomA = decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GmmaMajorA, ElementAMma, decltype(cute::get<0>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + using SmemLayoutAtomB = decltype(cutlass::gemm::collective::detail::ss_smem_selector< + GmmaMajorB, ElementBMma, decltype(cute::get<1>(TileShape_MNK{})), decltype(cute::get<2>(TileShape_MNK{}))>()); + + static constexpr int PipelineStages = detail::compute_stage_count_or_override(StageCountType{}); + + using SmemLayoutA = decltype(tile_to_shape( + SmemLayoutAtomA{}, + make_shape(shape<0>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}), + Step<_2,_1,_3>{})); + using SmemLayoutB = decltype(tile_to_shape( + SmemLayoutAtomB{}, + make_shape(shape<1>(TileShape_MNK{}), shape<2>(TileShape_MNK{}), Int{}), + Step<_2,_1,_3>{})); + + constexpr static int NumSpatialDimensions = cutlass::conv::collective::detail::gmem_layout_tags_to_spatial_dims(); + + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedImplicitGemm< + ConvOp, PipelineStages, NumSpatialDimensions, ClusterShape_MNK, KernelScheduleType>; + + using CollectiveOp = CollectiveConv< + DispatchPolicy, + TileShape_MNK, + ElementA, + ElementB, + TiledMma, + detail::Sm90ImplicitGemmTileTraits, + detail::Sm90ImplicitGemmTileTraits + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// GMMA auto kernel schedule +template < + conv::Operator ConvOp, + class ElementA, + class GmemLayoutA, + int AlignmentA, + class ElementB, + class GmemLayoutB, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType +> +struct CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ConvOp, + ElementA, + GmemLayoutA, + AlignmentA, + ElementB, + GmemLayoutB, + AlignmentB, + ElementAccumulator, + TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelScheduleType, + cute::enable_if_t> +> { + static_assert(is_static::value); + static_assert(is_static::value); +#ifndef CUTLASS_SM90_COLLECTIVE_BUILDER_SUPPORTED + static_assert(cutlass::detail::dependent_false, "Unsupported Toolkit for SM90 Collective Builder\n"); +#endif + +/* +#if ((__CUDACC_VER_MAJOR__ > 12) || ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 1))) + // Cooperative schedule performs best for CUDA Toolkits with version >= 12.1 + + // For TileShape_M == 64, choosing KernelTmaWarpSpecialized as the KernelSchedule + // Since KernelTmaWarpSpecializedCooperative requires TileShape_M to be at least 128 + using KernelWarpSpecializedSchedule = cute::conditional_t(TileShape_MNK{}) == Int<64>{}, + KernelImplicitTmaWarpSpecializedSm90PingPong, KernelImplicitTmaWarpSpecializedSm90Cooperative>; +#else + using KernelWarpSpecializedSchedule = KernelImplicitTmaWarpSpecializedSm90; +#endif +*/ + using KernelWarpSpecializedSchedule = KernelImplicitTmaWarpSpecializedSm90; + + using CollectiveOp = typename CollectiveBuilder< + arch::Sm90, + arch::OpClassTensorOp, + ConvOp, + ElementA, + GmemLayoutA, + AlignmentA, + ElementB, + GmemLayoutB, + AlignmentB, + ElementAccumulator, + 
TileShape_MNK, + ClusterShape_MNK, + StageCountType, + KernelWarpSpecializedSchedule + >::CollectiveOp; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/collective/collective_builder.hpp b/include/cutlass/conv/collective/collective_builder.hpp new file mode 100644 index 0000000000..9d6a16c0db --- /dev/null +++ b/include/cutlass/conv/collective/collective_builder.hpp @@ -0,0 +1,93 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/conv/collective/collective_conv.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Used to specify stage counts or dispatch to automatic computation of stage count +template +struct StageCount { + static constexpr int value = num_stages; + + StageCount() = default; + explicit StageCount(cute::Int) {} +}; + +template +struct StageCountAutoCarveout { + static constexpr int bytes = carveout_bytes; + + StageCountAutoCarveout() = default; + explicit StageCountAutoCarveout(cute::Int) {} +}; + +// Used to automatically let the builder pick the kernel schedule. 
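For orientation, the specializations above let a kernel author request a conv mainloop the same way the existing GEMM collective builder is used: supply the operator, operand types, layout tags, alignments, tile/cluster shapes, a stage-count policy, and a schedule, then read back `CollectiveOp`. The sketch below shows one such instantiation; the element types, NHWC layout tags, tile and cluster shapes, and the 32 KB carveout are illustrative choices rather than defaults defined by this header.

```cpp
#include "cutlass/conv/collective/collective_builder.hpp"
#include "cutlass/half.h"
#include "cutlass/layout/tensor.h"

// Hypothetical fprop mainloop built through the new conv CollectiveBuilder.
using CollectiveMainloop = cutlass::conv::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    cutlass::conv::Operator::kFprop,
    cutlass::half_t, cutlass::layout::TensorNHWC, 8,   // A (activation), 128-bit aligned
    cutlass::half_t, cutlass::layout::TensorNHWC, 8,   // B (filter), 128-bit aligned
    float,                                             // accumulator
    cute::Shape<cute::_128, cute::_128, cute::_64>,    // CTA tile (M, N, K)
    cute::Shape<cute::_1, cute::_1, cute::_1>,         // cluster shape
    cutlass::conv::collective::StageCountAutoCarveout<32 * 1024>,  // leave smem for the epilogue
    cutlass::conv::collective::KernelScheduleAuto      // defer schedule choice to the builder
  >::CollectiveOp;
```

Passing `KernelScheduleAuto` routes through the auto-schedule specialization above, which currently resolves to `KernelImplicitTmaWarpSpecializedSm90`.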
+// Can be overridden with kernel schedule tags in cutlass/conv/dispatch_policy.hpp +struct KernelScheduleAuto {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class ArchTag, + class OpClass, + conv::Operator, + class ElementA, + class GmemLayoutA, + int AlignmentA, + class ElementB, + class GmemLayoutB, + int AlignmentB, + class ElementAccumulator, + class TileShape_MNK, + class ClusterShape_MNK, + class StageCountType, + class KernelScheduleType, + class Enable = void +> +struct CollectiveBuilder { + static_assert(cutlass::detail::dependent_false, "Could not build a collective for given parameters."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "builders/sm90_gmma_builder.inl" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/collective/collective_conv.hpp b/include/cutlass/conv/collective/collective_conv.hpp new file mode 100644 index 0000000000..d187b5ecee --- /dev/null +++ b/include/cutlass/conv/collective/collective_conv.hpp @@ -0,0 +1,62 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/conv/collective/detail.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class DispatchPolicy, + class TileShape, + class ElementA, + class ElementB, + class TiledMma, + class TileTraitsA, + class TileTraitsB +> +struct CollectiveConv { + static_assert(cutlass::detail::dependent_false, "Could not find a mainloop specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "sm90_implicit_gemm_gmma_ss_warpspecialized.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/collective/detail.hpp b/include/cutlass/conv/collective/detail.hpp new file mode 100644 index 0000000000..ac272c8e20 --- /dev/null +++ b/include/cutlass/conv/collective/detail.hpp @@ -0,0 +1,254 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/conv/convnd_problem_shape.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::collective::detail { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Construct the stride types for conv collectives based on the dispatch policy, strides 64b by default +template +constexpr auto +sm90_dispatch_policy_to_stride_A() { + if constexpr (DispatchPolicy::ConvOp == conv::Operator::kFprop) { + // Maps to modes ((w,n), C) + if constexpr (DispatchPolicy::NumSpatialDimensions == 1) { + return cute::Stride, + cute::Int<1>>{}; + } + // Maps to modes ((w,h,n), C) + else if constexpr (DispatchPolicy::NumSpatialDimensions == 2) { + return cute::Stride, + cute::Int<1>>{}; + } + // Maps to modes ((w,h,d,n), C) + else if constexpr (DispatchPolicy::NumSpatialDimensions == 3) { + return cute::Stride, + cute::Int<1>>{}; + } + // error dims assert + else { + static_assert(cutlass::detail::dependent_false, "Unsupported spatial dim count."); + } + } + else if constexpr (DispatchPolicy::ConvOp == conv::Operator::kWgrad) { + // Maps to modes (k, nq/npq/nzpq) + if constexpr (DispatchPolicy::NumSpatialDimensions == 1 || + DispatchPolicy::NumSpatialDimensions == 2 || + DispatchPolicy::NumSpatialDimensions == 3) { + return cute::Stride, int64_t>{}; + } + // error dims assert + else { + static_assert(cutlass::detail::dependent_false, "Unsupported spatial dim count."); + } + } + else if constexpr (DispatchPolicy::ConvOp == conv::Operator::kDgrad) { + // Maps to modes ((q,n), K) + if constexpr (DispatchPolicy::NumSpatialDimensions == 1) { + return cute::Stride, + cute::Int<1>>{}; + } + // Maps to modes ((q,p,n), K) + else if constexpr (DispatchPolicy::NumSpatialDimensions == 2) { + return cute::Stride, + cute::Int<1>>{}; + } + // Maps to modes ((q,p,z,n), K) + else if constexpr (DispatchPolicy::NumSpatialDimensions == 3) { + return cute::Stride, + cute::Int<1>>{}; + } + // error dims assert + else { + static_assert(cutlass::detail::dependent_false, "Unsupported spatial dim count."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Unsupported ConvOp."); + } +} + +// Construct the stirde types for conv collectives based on the dispatch policy, strides 64b by default +template +constexpr auto +sm90_dispatch_policy_to_stride_B() { + if constexpr (DispatchPolicy::ConvOp == conv::Operator::kFprop) { + // Maps to modes (k, (C,s)) + if constexpr (DispatchPolicy::NumSpatialDimensions == 1) { + return cute::Stride, int64_t>>{}; + } + // Maps to modes (k, (C,s,r)) + else if constexpr (DispatchPolicy::NumSpatialDimensions == 2) { + return cute::Stride, int64_t, int64_t>>{}; + } + // Maps to modes (k, (C,s,r,t)) + else if constexpr (DispatchPolicy::NumSpatialDimensions == 3) { + return cute::Stride, int64_t, int64_t, int64_t>>{}; + } + // error dims assert + else { + static_assert(cutlass::detail::dependent_false, "Unsupported spatial dim count."); + } + } + else if constexpr (DispatchPolicy::ConvOp == conv::Operator::kWgrad) { + // Maps to modes (C, (w,n)) + if constexpr (DispatchPolicy::NumSpatialDimensions == 1) { + return cute::Stride, + cute::Stride>{}; + } + // Maps to modes (C, (w,h,n)) + else if constexpr (DispatchPolicy::NumSpatialDimensions == 2) { + return cute::Stride, + cute::Stride>{}; + } + // Maps to modes (C, (w,h,d,n)) + else if 
constexpr (DispatchPolicy::NumSpatialDimensions == 3) { + return cute::Stride, + cute::Stride>{}; + } + // error dims assert + else { + static_assert(cutlass::detail::dependent_false, "Unsupported spatial dim count."); + } + } + else if constexpr (DispatchPolicy::ConvOp == conv::Operator::kDgrad) { + // Maps to modes (C, (k,s)) + if constexpr (DispatchPolicy::NumSpatialDimensions == 1) { + return cute::Stride, cute::Stride>{}; + } + // Maps to modes (C, (k,s,r)) + else if constexpr (DispatchPolicy::NumSpatialDimensions == 2) { + return cute::Stride, cute::Stride>{}; + } + // Maps to modes (C, (k,s,r,t)) + else if constexpr (DispatchPolicy::NumSpatialDimensions == 3) { + return cute::Stride, cute::Stride>{}; + } + // error dims assert + else { + static_assert(cutlass::detail::dependent_false, "Unsupported spatial dim count."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Unsupported ConvOp."); + } +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Compute the lower/near corner, returning it as a cute::array in [W,H,D] order +template +CUTLASS_HOST_DEVICE +constexpr auto +compute_lower_corner_whd(ConvProblemShape const& problem_shape) { + using cute::for_each; + using cute::make_seq; + + cute::array lower{}; + if constexpr (ConvOp == conv::Operator::kFprop || + ConvOp == conv::Operator::kWgrad) { + for_each(make_seq{}, [&](auto i) { + lower[NumSpatialDimensions-1-i] = -1 * problem_shape.lower_padding[i]; + }); + } + else if constexpr (ConvOp == conv::Operator::kDgrad) { + for_each(make_seq{}, [&](auto i) { + lower[NumSpatialDimensions-1-i] = problem_shape.lower_padding[i] - + (problem_shape.shape_B[i+1] - 1) * problem_shape.dilation[i]; + }); + } + return lower; +} + +// Computes the upper/far corner, returning it as a cute::array in [W,H,D] order +template +CUTLASS_HOST_DEVICE +constexpr auto +compute_upper_corner_whd(ConvProblemShape const& problem_shape) { + using cute::for_each; + using cute::make_seq; + + cute::array upper{}; + if constexpr (ConvOp == conv::Operator::kFprop) { + for_each(make_seq{}, [&](auto i) { + upper[NumSpatialDimensions-1-i] = problem_shape.upper_padding[i] - + (problem_shape.shape_B[i+1] - 1) * problem_shape.dilation[i]; + }); + } + else if constexpr (ConvOp == conv::Operator::kWgrad) { + for_each(make_seq{}, [&](auto i) { + upper[NumSpatialDimensions-1-i] = problem_shape.upper_padding[i] - + (problem_shape.shape_C[i+1] - 1) * problem_shape.dilation[i]; + }); + } + else if constexpr (ConvOp == conv::Operator::kDgrad) { + for_each(make_seq{}, [&](auto i) { + upper[NumSpatialDimensions-1-i] = problem_shape.lower_padding[i] - + (problem_shape.shape_B[i+1] - 1) * problem_shape.dilation[i] + problem_shape.shape_C[i+1] - problem_shape.shape_A[i+1]; + }); + } + return upper; +} + +// Compute the lower/near corner of (t,r,s), returning it as a cute::array in [S,R,T] order +template +CUTLASS_HOST_DEVICE +constexpr auto +compute_lower_srt(ConvProblemShape const& problem_shape) { + using cute::for_each; + using cute::make_seq; + + cute::array lower{}; + if constexpr (ConvOp == conv::Operator::kFprop || + ConvOp == conv::Operator::kWgrad) { + for_each(make_seq{}, [&](auto i) { + lower[NumSpatialDimensions-1-i] = 0; + }); + } + else if constexpr (ConvOp == conv::Operator::kDgrad) { + for_each(make_seq{}, [&](auto i) { + lower[NumSpatialDimensions-1-i] = (problem_shape.shape_B[i+1] - 1) * problem_shape.dilation[i]; + }); + } + return lower; +} + +template struct is_im2col_load { static constexpr 
bool value = false; }; +template <> struct is_im2col_load { static constexpr bool value = true; }; +template <> struct is_im2col_load { static constexpr bool value = true; }; +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::collective::detail diff --git a/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp b/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp new file mode 100644 index 0000000000..0e5d898d0e --- /dev/null +++ b/include/cutlass/conv/collective/sm90_implicit_gemm_gmma_ss_warpspecialized.hpp @@ -0,0 +1,753 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor_predicate.hpp" +#include "cute/arch/cluster_sm90.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/atom/mma_atom.hpp" +#include "cute/atom/copy_traits_sm90_im2col.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" +#include "cute/algorithm/functional.hpp" +#include "cute/algorithm/gemm.hpp" + +#include "cutlass/conv/detail.hpp" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/dispatch_policy.hpp" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/util/packed_stride.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::collective { +using namespace cute; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + conv::Operator ConvOp, + int Stages, + int NumSpatialDims, + class ClusterShape, + class KernelSchedule, + int PipelineAsyncMmaStages, + class TileShape_, + class ElementA_, + class ElementB_, + class TiledMma_, + class TileTraitsA_, + class TileTraitsB_> +struct CollectiveConv< + MainloopSm90TmaGmmaWarpSpecializedImplicitGemm< + ConvOp, Stages, NumSpatialDims, ClusterShape, KernelSchedule, PipelineAsyncMmaStages>, + TileShape_, + ElementA_, + ElementB_, + TiledMma_, + TileTraitsA_, + TileTraitsB_> +{ + // + // Type Aliases + // + using DispatchPolicy = MainloopSm90TmaGmmaWarpSpecializedImplicitGemm< + ConvOp, Stages, NumSpatialDims, ClusterShape, KernelSchedule, PipelineAsyncMmaStages>; + using TileShape = TileShape_; + using ElementA = ElementA_; + using ElementB = ElementB_; + using TiledMma = TiledMma_; + using ElementAccumulator = typename TiledMma::ValTypeC; + using GmemTiledCopyA = typename TileTraitsA_::GmemTiledCopy; + using GmemTiledCopyB = typename TileTraitsB_::GmemTiledCopy; + using SmemLayoutA = typename TileTraitsA_::SmemLayout; + using SmemLayoutB = typename TileTraitsB_::SmemLayout; + using ArchTag = typename DispatchPolicy::ArchTag; + static constexpr int NumSpatialDimensions = DispatchPolicy::NumSpatialDimensions; + static constexpr int NumTensorDimensions = NumSpatialDimensions + 2; + // Deduce the kernel-facing stride tuple types based on the dispatch policy + // (which is a function of the number of spatial dimensions, the algorithm, etc.) 
+ using StrideA = decltype(detail::sm90_dispatch_policy_to_stride_A()); + using StrideB = decltype(detail::sm90_dispatch_policy_to_stride_B()); + + using MainloopPipeline = cutlass::PipelineTmaAsync; + + using PipelineParams = typename MainloopPipeline::Params; + using PipelineState = typename cutlass::PipelineState; + + using ProblemShape = ConvProblemShape; + + // TODO: move pipeline mode tiling into the collective setup phase instead + static_assert(rank(SmemLayoutA{}) == 3, "SmemLayout must be rank 3 (M/N, K, PIPE)"); + static_assert((size<0>(TileShape{}) == size<0>(SmemLayoutA{})), "SmemLayout must be compatible with the tile shape."); + static_assert((size<2>(TileShape{}) == size<1>(SmemLayoutA{})), "SmemLayout must be compatible with the tile shape."); + + static_assert(rank(SmemLayoutB{}) == 3, "SmemLayout must be rank 3 (M/N, K, PIPE)"); + static_assert((size<1>(TileShape{}) == size<0>(SmemLayoutB{})), "SmemLayout must be compatible with the tile shape."); + static_assert((size<2>(TileShape{}) == size<1>(SmemLayoutB{})), "SmemLayout must be compatible with the tile shape."); + + static_assert(DispatchPolicy::Stages >= 2, "Specialization requires Stages set to value 1 or more."); + static_assert(cute::is_base_of::value && + cute::is_base_of::value, + "MMA atom must source both A and B operand from smem_desc for this mainloop."); + + // The tma load mode of wgrad is tiled for tensor A and im2col for tensor B while the tma load mode of fprop and dgrad + // kernel is im2col for tensor A and tiled for tensor B. + static_assert((ConvOp == conv::Operator::kWgrad + && (cute::is_same_v || cute::is_same_v)) + || (ConvOp != conv::Operator::kWgrad + && (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopyA - invalid SM90 TMA copy atom specified."); + static_assert((ConvOp == conv::Operator::kWgrad + && (cute::is_same_v || cute::is_same_v)) + || (ConvOp != conv::Operator::kWgrad + && (cute::is_same_v || cute::is_same_v)), + "GmemTiledCopyB - invalid SM90 TMA copy atom specified."); + + static constexpr bool is_im2col_A = detail::is_im2col_load::value; + static constexpr bool is_im2col_B = detail::is_im2col_load::value; + + // TMA converts f32 input to tf32 when copying from GMEM to SMEM + // For all other types, cast to size equivalent uint type to avoid any rounding by TMA. 
+ static constexpr bool ConvertF32toTF32A = cute::is_same_v; + static constexpr bool ConvertF32toTF32B = cute::is_same_v; + using InternalElementA = cute::conditional_t>>; + using InternalElementB = cute::conditional_t>>; + + struct SharedStorage + { + struct TensorStorage : cute::aligned_struct<128, _0> { + cute::array_aligned> smem_A; + cute::array_aligned> smem_B; + } tensors; + + using PipelineStorage = typename MainloopPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr int K_PIPE_MAX = DispatchPolicy::Stages; + static constexpr int K_PIPE_MMAS = DispatchPolicy::PipelineAsyncMmaStages; + static constexpr uint32_t TmaTransactionBytes = + (size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(sizeof(InternalElementA)))+ + (size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(sizeof(InternalElementB))); + + // Host side kernel arguments + struct Arguments { + ElementA const* ptr_A{nullptr}; + ElementB const* ptr_B{nullptr}; + }; + +private: + // Note that for fprop and dgrad kernel, the tma load mode is im2col for tensor A and tiled for + // tensor B while for wgrad kernel, the tma load mode is tiled for tensor A and im2col for tensor + // B since operand A, B is swapped. + // Get tma_load_a instantce. + template + static constexpr auto + get_tma_load_a_instance(TensorA const& tensor_a, ProblemShape const& problem_shape) { + if constexpr (is_im2col_A) { + // compute the upper and lower corners based on the conv padding + auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); + auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape); + auto lower_srt = detail::compute_lower_srt(problem_shape); + + // The calculation of gbasis strides for dgrad kernel needs perform negate for dilation values. + cute::array stride_srt{}; + for (int i = 0; i < NumSpatialDimensions; ++i) { + stride_srt[i] = ConvOp == conv::Operator::kDgrad ? + -problem_shape.dilation[NumSpatialDimensions-1-i] : + problem_shape.dilation[NumSpatialDimensions-1-i]; + } + + return make_im2col_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_0{}), + product_each(shape(SmemLayoutA{}(_,_,_0{}))), + size<1>(ClusterShape{}), + shape(lower_corner_whd), + shape(upper_corner_whd), + cute::reverse(shape(problem_shape.lower_padding)), + cute::reverse(shape(problem_shape.upper_padding)), + cute::reverse(shape(problem_shape.traversal_stride)), + shape(lower_srt), + shape(stride_srt)); + } + // TMA tiled mode for tensor A in wgrad kernel. + else { + return make_tma_copy( + GmemTiledCopyA{}, + tensor_a, + SmemLayoutA{}(_,_,_0{}), + make_shape(shape<0>(TileShape{}), shape<2>(TileShape{})), + size<1>(ClusterShape{})); + } + } + + // Get tma_load_b instantce. + template + static constexpr auto + get_tma_load_b_instance(TensorB const& tensor_b, ProblemShape const& problem_shape) { + // TMA im2col mode for tensor B in wgrad kernel. 
+ if constexpr (is_im2col_B) { + // compute the upper and lower corners based on the conv padding + auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); + auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape); + auto lower_srt = detail::compute_lower_srt(problem_shape); + + return make_im2col_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_0{}), + product_each(shape(SmemLayoutB{}(_,_,_0{}))), + size<0>(ClusterShape{}), + shape(lower_corner_whd), + shape(upper_corner_whd), + cute::reverse(shape(problem_shape.lower_padding)), + cute::reverse(shape(problem_shape.upper_padding)), + cute::reverse(shape(problem_shape.traversal_stride)), + shape(lower_srt), + cute::reverse(shape(problem_shape.dilation))); + } + else { + return make_tma_copy( + GmemTiledCopyB{}, + tensor_b, + SmemLayoutB{}(_,_,_0{}), + make_shape(shape<1>(TileShape{}), shape<2>(TileShape{})), + size<0>(ClusterShape{})); + } + } + +public: + + // Performs im2col transformations on the input of type ConvProblemShape + static constexpr auto + get_problem_shape_MNKL(ProblemShape const& problem_shape) { + + if constexpr (is_im2col_A || is_im2col_B) { + // transformation + im2col linearization + return cutlass::conv::detail::get_linearized_problem_shape_MNKL(problem_shape); + } + else { + // transformation + return cutlass::conv::detail::get_transformed_problem_shape_MNKL(problem_shape); + } + } + + // Device side kernel params + struct Params { + using _Submode = decltype(take<0,NumTensorDimensions-1>(typename ProblemShape::TensorExtent{})); + + // Assumption: StrideA is congruent with Problem_MK + // Select TMA load type according to convolution operator. + using TensorShapeA = cute::conditional_t; + + using TensorShapeB = cute::conditional_t; + + using TMA_A = decltype(get_tma_load_a_instance( + make_tensor( + make_gmem_ptr(static_cast(nullptr)), + make_layout(TensorShapeA{}, StrideA{})), + ConvProblemShape{})); + + using TMA_B = decltype(get_tma_load_b_instance( + make_tensor( + make_gmem_ptr(static_cast(nullptr)), + make_layout(TensorShapeB{}, StrideB{})), + ConvProblemShape{})); + + // Members + TMA_A tma_load_a; + TMA_B tma_load_b; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + }; + + // + // Methods + // + + // Lowers the host side user facing arguments to the kernel facing lauch params + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + (void) workspace; + // from the flat problem shape arrays of ConvProblemShape, create a rank-3 MNK problem shape tuple + // tma desc creation depends on the original untransformed domain. + + // A extents. + auto shape_A_orig = problem_shape.get_shape_A(); + // B extents. 
+ auto shape_B_orig = problem_shape.get_shape_B(); + + // Fill inferred cute strides from flat stride arrays + auto dA = make_cute_packed_stride(StrideA{}, problem_shape.stride_A, ConvOp); + auto dB = make_cute_packed_stride(StrideB{}, problem_shape.stride_B, ConvOp); + + auto ptr_A = reinterpret_cast(args.ptr_A); + auto ptr_B = reinterpret_cast(args.ptr_B); + + Tensor tensor_a = make_tensor(make_gmem_ptr(ptr_A), make_layout(shape_A_orig, dA)); + Tensor tensor_b = make_tensor(make_gmem_ptr(ptr_B), make_layout(shape_B_orig, dB)); + + auto tma_load_a = get_tma_load_a_instance(tensor_a, problem_shape); + auto tma_load_b = get_tma_load_b_instance(tensor_b, problem_shape); + + return { + tma_load_a, + tma_load_b, + TmaTransactionBytes + }; + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + Arguments const& args) { + // Activation and Filter channel mode extents much match + bool implementable = true; + // channel mode is major + implementable &= problem_shape.stride_A[NumTensorDimensions-1] == 1; + implementable &= problem_shape.stride_B[NumTensorDimensions-1] == 1; + + constexpr int tma_alignment_bits = 128; + // A extents. + auto shape_A_orig = problem_shape.get_shape_A(); + // B extents. + auto shape_B_orig = problem_shape.get_shape_B(); + constexpr int min_tma_aligned_elements_A = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(shape_A_orig, StrideA{}); + constexpr int min_tma_aligned_elements_B = tma_alignment_bits / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(shape_B_orig, StrideB{}); + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + return false; + } + + // Check valid padding values for TMA_LOAD_IM2COL + constexpr int padding_limit = (ProblemShape::RankS == 1) ? 65536 : (ProblemShape::RankS == 2 ? 
256 : 16); + for (int i = 0; i < problem_shape.RankS; ++i) { + implementable = implementable && problem_shape.lower_padding[i] <= padding_limit && problem_shape.lower_padding[i] >= 0; + implementable = implementable && problem_shape.upper_padding[i] <= padding_limit && problem_shape.upper_padding[i] >= 0; + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Padding values don't meet requirements for TMA LOAD IM2COL.\n"); + return false; + } + + if (is_im2col_A || is_im2col_B) { + // Check valid corner values for TMA_LOAD_IM2COL, signed int ranging from [-corner_limit, corner_limit - 1] + constexpr int32_t corner_limit = 1 << (16 / NumSpatialDimensions - 1); + auto lower_corner_whd = detail::compute_lower_corner_whd(problem_shape); + for (int i = 0; i < problem_shape.RankS; ++i) { + implementable = implementable && lower_corner_whd[i] >= -corner_limit && lower_corner_whd[i] <= (corner_limit - 1); + } + auto upper_corner_whd = detail::compute_upper_corner_whd(problem_shape); + for (int i = 0; i < problem_shape.RankS; ++i) { + implementable = implementable && upper_corner_whd[i] >= -corner_limit && upper_corner_whd[i] <= (corner_limit - 1); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Padding values don't meet requirements for TMA LOAD IM2COL.\n"); + return false; + } + } + + // Wgrad kernels don't support non-packed output strides, non-packed tensor A stride (linearized) + if constexpr (ConvOp == conv::Operator::kWgrad) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + std::ostringstream os; +#endif + const auto & input_shape = problem_shape.shape_A; + const auto & input_stride = problem_shape.stride_A; + + implementable &= input_stride[ProblemShape::RankT - 1] == 1; + int input_shape_size = 1; + for (int i = ProblemShape::RankT - 2; i >= 0; --i) { + input_shape_size *= input_shape[i + 1]; + implementable &= input_stride[i] == input_shape_size; +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + if (input_stride[i] != input_shape_size) { + os << "\n *** input_stride[" << i << "] = " << input_stride[i] << " != input_shape_size = " << input_shape_size << " ***"; + } +#endif + } + + if (!implementable) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + os << "\n input_shape_size: " << input_shape_size + << "\n input_shape: " << input_shape + << "\n input_stride: " << input_stride + << "\n"; +#endif + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed input strides.\n"); +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST(os.str()); +#endif + return false; + } + + const auto & output_shape = problem_shape.shape_C; + const auto & output_stride = problem_shape.stride_C; + + implementable &= output_stride[ProblemShape::RankT - 1] == 1; + int output_shape_size = 1; + for (int i = ProblemShape::RankT - 2; i >= 0; --i) { + output_shape_size *= output_shape[i + 1]; + implementable &= output_stride[i] == output_shape_size; +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + if (output_stride[i] != output_shape_size) { + os << "\n *** output_stride[" << i << "] = " << output_stride[i] << " != output_shape_size = " << output_shape_size << " ***"; + } +#endif + } + + if (!implementable) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + os << "\n output_shape_size: " << input_shape_size + << "\n output_shape: " << input_shape + << "\n output_stride: " << input_stride + 
<< "\n"; +#endif + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Wgrad kernels don't support non-packed output strides.\n"); +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + CUTLASS_TRACE_HOST(os.str()); +#endif + return false; + } + } + + // Conv kernels only support cross correlation mode currently. + implementable &= problem_shape.mode == cutlass::conv::Mode::kCrossCorrelation; + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Conv kernels only support cross correlation mode currently.\n"); + return false; + } + + if (problem_shape.groups > 1) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: This kernel does not support conv groups > 1.\n"); + return false; + } + + return true; + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void prefetch_tma_descriptors(Params const& mainloop_params) { + cute::prefetch_tma_descriptor(mainloop_params.tma_load_a.get_tma_descriptor()); + cute::prefetch_tma_descriptor(mainloop_params.tma_load_b.get_tma_descriptor()); + } + + /// Set up the data needed by this collective for load and mma. + /// Returns a tuple of tensors. The collective and the kernel layer have the contract + /// Returned tuple must contain at least two elements, with the first two elements being: + /// gA_mk - The tma tensor, A after a local tile so it has shape (BLK_M,BLK_K,m,k) + /// gB_nk - The tma tensor, B after a local tile so it has shape (BLK_N,BLK_K,n,k) + /// The rest of the tensors can be specified as needed by this collective. + /// The dimensions of gA_mk and gA_nk do not contain L to maintain consistency with + /// StrideA and StrideB set up for TMA + template + CUTLASS_DEVICE auto + load_init(ProblemShapeMNKL const& problem_shape_MNKL, Params const& mainloop_params){ + //load_init(ProblemShapeMNKL const& problem_shape_MNKL, Params const& mainloop_params) const { + using X = Underscore; + // Separate out problem shape for convenience + auto [M, N, K, L] = problem_shape_MNKL; + + // TMA requires special handling of strides to deal with coord codomain mapping + // Represent the full tensors -- get these from TMA + Tensor mA_mk = mainloop_params.tma_load_a.get_tma_tensor(make_shape(M,K)); // (m,k) + Tensor mB_nk = mainloop_params.tma_load_b.get_tma_tensor(make_shape(N,K)); // (n,k) + + // Make tiled views, defer the slice + Tensor gA_mk = local_tile(mA_mk, TileShape{}, make_coord(_,_,_), Step<_1, X,_1>{}); // (BLK_M,BLK_K,m,k) + Tensor gB_nk = local_tile(mB_nk, TileShape{}, make_coord(_,_,_), Step< X,_1,_1>{}); // (BLK_N,BLK_K,n,k) + + return cute::make_tuple(gA_mk, gB_nk); + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Producer Perspective + template < + class TensorA, class TensorB, + class KTileIterator, class BlockCoord + > + CUTLASS_DEVICE void + load( + Params const& mainloop_params, + MainloopPipeline pipeline, + PipelineState smem_pipe_producer_state, + cute::tuple const& load_inputs, + BlockCoord const& blk_coord, + KTileIterator k_tile_iter, int k_tile_count, + int thread_idx, + uint32_t block_rank_in_cluster, + TensorStorage& shared_tensors) { + + int lane_predicate = cute::elect_one_sync(); + if (lane_predicate) { + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // Prepare the TMA loads for A and B + // + constexpr uint32_t cluster_shape_x = get<0>(ClusterShape()); + + 
uint2 cluster_local_block_id = {block_rank_in_cluster % cluster_shape_x, block_rank_in_cluster / cluster_shape_x}; + auto block_tma_a = mainloop_params.tma_load_a.get_slice(cluster_local_block_id.y); + auto block_tma_b = mainloop_params.tma_load_b.get_slice(cluster_local_block_id.x); + + auto [gA_mk, gB_nk] = load_inputs; + + // Partition the inputs based on the current block coordinates. + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord; + + Tensor gA = gA_mk(_,_,m_coord,_); // (BLK_M,BLK_K,k) + Tensor gB = gB_nk(_,_,n_coord,_); // (BLK_N,BLK_K,k) + + // Applies the mapping from block_tma_a + Tensor tAgA = block_tma_a.partition_S(gA); // (TMA,TMA_M,TMA_K,k) + Tensor tAsA = block_tma_a.partition_D(sA); // (TMA,TMA_M,TMA_K,PIPE) + + Tensor tBgB = block_tma_b.partition_S(gB); // (TMA,TMA_N,TMA_K,k) + Tensor tBsB = block_tma_b.partition_D(sB); // (TMA,TMA_N,TMA_K,PIPE) + + uint16_t mcast_mask_a = 0; + uint16_t mcast_mask_b = 0; + + // Issue TmaLoads + // Maps the tile -> block, value + if constexpr (cute::is_same_v || + cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int n = 0; n < size<1>(block_layout); ++n) { + mcast_mask_a |= (uint16_t(1) << block_layout(cluster_local_block_id.x,n,Int<0>{})); + } + } + + if constexpr (cute::is_same_v || + cute::is_same_v) { + auto block_layout = Layout{}; // (m,n) -> block_id + for (int m = 0; m < size<0>(block_layout); ++m) { + mcast_mask_b |= (uint16_t(1) << block_layout(m,cluster_local_block_id.y,Int<0>{})); + } + } + + // Mainloop + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // LOCK smem_pipe_producer_state for _writing_ + pipeline.producer_acquire(smem_pipe_producer_state); + + // + // Copy gmem to smem for *k_tile_iter + // + + using BarrierType = typename MainloopPipeline::ProducerBarrierType; + BarrierType* tma_barrier = pipeline.producer_get_barrier(smem_pipe_producer_state); + + int write_stage = smem_pipe_producer_state.index(); + + copy(mainloop_params.tma_load_a.with(*tma_barrier, mcast_mask_a), tAgA(_,_,_,*k_tile_iter), tAsA(_,_,_,write_stage)); + copy(mainloop_params.tma_load_b.with(*tma_barrier, mcast_mask_b), tBgB(_,_,_,*k_tile_iter), tBsB(_,_,_,write_stage)); + ++k_tile_iter; + + // Advance smem_pipe_producer_state + ++smem_pipe_producer_state; + } + } + } + + /// Perform a Producer Epilogue to prevent early exit of blocks in a Cluster + CUTLASS_DEVICE void + load_tail(MainloopPipeline pipeline, PipelineState smem_pipe_producer_state) { + int lane_predicate = cute::elect_one_sync(); + + // Issue the epilogue waits + if (lane_predicate) { + /* This helps avoid early exit of blocks in Cluster + * Waits for all stages to either be released (all + * Consumer UNLOCKs), or if the stage was never used + * then would just be acquired since the phase was + * still inverted from make_producer_start_state + */ + pipeline.producer_tail(smem_pipe_producer_state); + } + } + + /// Perform a collective-scoped matrix multiply-accumulate + /// Consumer Perspective + template + CUTLASS_DEVICE void + mma(MainloopPipeline pipeline, + PipelineState smem_pipe_consumer_state, + FrgTensorC& accum, + int k_tile_count, + int thread_idx, + TensorStorage& shared_tensors, + Params const& mainloop_params) { + static_assert(is_rmem::value, "C tensor must be rmem resident."); + + Tensor sA = make_tensor(make_smem_ptr(shared_tensors.smem_A.data()), SmemLayoutA{}); // (BLK_M,BLK_K,PIPE) + Tensor sB = make_tensor(make_smem_ptr(shared_tensors.smem_B.data()), SmemLayoutB{}); // (BLK_N,BLK_K,PIPE) + + // + // 
Define C accumulators and A/B partitioning + // + + TiledMma tiled_mma; + auto thread_mma = tiled_mma.get_thread_slice(thread_idx); + + Tensor tCsA = thread_mma.partition_A(sA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCsB = thread_mma.partition_B(sB); // (MMA,MMA_N,MMA_K,PIPE) + + // Allocate "fragments/descriptors" + Tensor tCrA = thread_mma.make_fragment_A(tCsA); // (MMA,MMA_M,MMA_K,PIPE) + Tensor tCrB = thread_mma.make_fragment_B(tCsB); // (MMA,MMA_N,MMA_K,PIPE) + + CUTE_STATIC_ASSERT_V(size<1>(tCsA) == size<1>(accum)); // M + CUTE_STATIC_ASSERT_V(size<1>(tCsB) == size<2>(accum)); // N + CUTE_STATIC_ASSERT_V(size<2>(tCsA) == size<2>(tCsB)); // K + CUTE_STATIC_ASSERT_V(size<3>(tCsA) == size<3>(tCsB)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sA)); // PIPE + CUTE_STATIC_ASSERT_V(Int{} == size<2>(sB)); // PIPE + + // + // PIPELINED MAIN LOOP + // + static_assert((0 <= K_PIPE_MMAS) && (K_PIPE_MMAS < K_PIPE_MAX), + "ERROR : Incorrect number of MMAs in flight"); + + // We release buffers to producer warps(dma load) with some mmas in flight + PipelineState smem_pipe_release = smem_pipe_consumer_state; + + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + + tiled_mma.accumulate_ = GMMA::ScaleOut::Zero; + + warpgroup_fence_operand(accum); + CUTLASS_PRAGMA_UNROLL + for (int k_tile_prologue = prologue_mma_count; k_tile_prologue > 0; --k_tile_prologue) { + // WAIT on smem_pipe_consumer_state until its data are available (phase bit flips from rdPhaseBit value) + pipeline.consumer_wait(smem_pipe_consumer_state); + + int read_stage = smem_pipe_consumer_state.index(); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M,K) x (V,N,K) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + + warpgroup_commit_batch(); + + ++smem_pipe_consumer_state; + } + + warpgroup_fence_operand(accum); + // Mainloop GMMAs + k_tile_count -= prologue_mma_count; + + CUTLASS_PRAGMA_NO_UNROLL + for ( ; k_tile_count > 0; --k_tile_count) { + // WAIT on smem_pipe_consumer_state until its data are available (phase bit flips from rdPhaseBit value) + pipeline.consumer_wait(smem_pipe_consumer_state); + + // + // Compute on k_tile + // + + int read_stage = smem_pipe_consumer_state.index(); + warpgroup_fence_operand(accum); + warpgroup_arrive(); + // Unroll the K mode manually to set scale D to 1 + CUTLASS_PRAGMA_UNROLL + for (int k_block = 0; k_block < size<2>(tCrA); ++k_block) { + // (V,M) x (V,N) => (V,M,N) + cute::gemm(tiled_mma, tCrA(_,_,k_block,read_stage), tCrB(_,_,k_block,read_stage), accum); + tiled_mma.accumulate_ = GMMA::ScaleOut::One; + } + warpgroup_commit_batch(); + + /// Wait on the GMMA barrier for K_PIPE_MMAS (or fewer) outstanding to ensure smem_pipe_producer_state is consumed + warpgroup_wait(); + warpgroup_fence_operand(accum); + + // UNLOCK smem_pipe_release, done _computing_ on it + pipeline.consumer_release(smem_pipe_release); + + // Advance smem_pipe_consumer_state and smem_pipe_release + ++smem_pipe_consumer_state; + ++smem_pipe_release; + } + + warpgroup_fence_operand(accum); + } + + /// Perform a Consumer Epilogue to release all buffers + CUTLASS_DEVICE void + mma_tail(MainloopPipeline pipeline, PipelineState smem_pipe_release, int k_tile_count) { + // Prologue GMMAs + int prologue_mma_count = min(K_PIPE_MMAS, k_tile_count); + k_tile_count -= prologue_mma_count; 
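+    // Each steady-state iteration of mma() has already released its own buffer,
+    // so only the prologue_mma_count stages issued before that loop are still
+    // locked. Advance the release state past the stages released in mma(), wait
+    // for all outstanding GMMAs, then unlock the remaining prologue stages.
+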
+ + smem_pipe_release.advance(k_tile_count); + + // Wait on all GMMAs to complete + warpgroup_wait<0>(); + + for (int count = 0; count < prologue_mma_count; ++count) { + pipeline.consumer_release(smem_pipe_release); // UNLOCK smem_pipe_release, done _computing_ on it + ++smem_pipe_release; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/conv2d_problem_size.h b/include/cutlass/conv/conv2d_problem_size.h index 7d0c86f478..d2e8952998 100644 --- a/include/cutlass/conv/conv2d_problem_size.h +++ b/include/cutlass/conv/conv2d_problem_size.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -35,7 +35,7 @@ activation (NHWC), filter (KRSC), output (NPQK), - pading (pad_h, pad_w), + pading (pad_h, pad_w), stride (stride_h, stride_w), dilation (dilation_h, dilation_w). @@ -47,17 +47,10 @@ #pragma once - -#if defined(__CUDACC_RTC__) -#include -#else -#include -#endif - #include "cutlass/cutlass.h" #include "cutlass/tensor_coord.h" #include "cutlass/fast_math.h" -#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/gemm_enumerated_types.h" #include "cutlass/matrix_coord.h" #include "cutlass/conv/convolution.h" #include "cutlass/functional.h" @@ -87,7 +80,7 @@ struct Conv2dProblemSize { public: CUTLASS_HOST_DEVICE - Conv2dProblemSize(): + Conv2dProblemSize(): N(0), H(0), W(0), C(0), P(0), Q(0), K(0), R(0), S(0), pad_h(0), pad_w(0), stride_h(1), stride_w(1), dilation_h(1), dilation_w(1), mode(Mode::kConvolution), split_k_slices(1), groups(1) { } @@ -107,7 +100,7 @@ struct Conv2dProblemSize { Mode mode ): N(N), H(H), W(W), C(C), P(P), Q(Q), K(K), R(R), S(S), - pad_h(R / 2), pad_w(S / 2), stride_h(1), stride_w(1), dilation_h(1), dilation_w(1), + pad_h(R / 2), pad_w(S / 2), stride_h(1), stride_w(1), dilation_h(1), dilation_w(1), mode(mode), split_k_slices(1), groups (1) { } /// Constructor @@ -131,9 +124,9 @@ struct Conv2dProblemSize { Mode mode, int split_k_slices = 1, int groups = 1 - ): - N(N), H(H), W(W), C(C), K(K), R(R), S(S), P(P), Q(Q), - pad_h(pad_h), pad_w(pad_w), stride_h(stride_h), stride_w(stride_w), + ): + N(N), H(H), W(W), C(C), P(P), Q(Q), K(K), R(R), S(S), + pad_h(pad_h), pad_w(pad_w), stride_h(stride_h), stride_w(stride_w), dilation_h(dilation_h), dilation_w(dilation_w), mode(mode), split_k_slices(split_k_slices), groups (groups) { } @@ -152,11 +145,11 @@ struct Conv2dProblemSize { int groups = 1 ): N(input_size.n()), H(input_size.h()), W(input_size.w()), C(input_size.c()), + P(output_size.h()), Q(output_size.w()), K(filter_size.n()), R(filter_size.h()), S(filter_size.w()), - pad_h(padding[0]), pad_w(padding[2]), - stride_h(stride.row()), stride_w(stride.column()), + pad_h(padding[0]), pad_w(padding[2]), + stride_h(stride.row()), stride_w(stride.column()), dilation_h(dilation.row()), dilation_w(dilation.column()), - P(output_size.h()), Q(output_size.w()), mode(mode), split_k_slices(split_k_slices), groups(groups) {} /// Constructs convolution problem size from cutlass Tensor4DCoord and MatrixCoord @@ -165,7 
+158,7 @@ struct Conv2dProblemSize { Conv2dProblemSize( cutlass::Tensor4DCoord input_size, // NHWC cutlass::Tensor4DCoord filter_size, // KRSC - cutlass::Tensor4DCoord padding, // pad_h, _, pad_w, _ + cutlass::Tensor4DCoord padding, // pad_h, upper_pad_h, pad_w, upper_pad_w cutlass::MatrixCoord stride, // stride_h, stride_w cutlass::MatrixCoord dilation, // dilation_h, dilation_w cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation, @@ -175,12 +168,12 @@ struct Conv2dProblemSize { N(input_size.n()), H(input_size.h()), W(input_size.w()), C(input_size.c()), K(filter_size.n()), R(filter_size.h()), S(filter_size.w()), pad_h(padding[0]), pad_w(padding[2]), - stride_h(stride.row()), stride_w(stride.column()), + stride_h(stride.row()), stride_w(stride.column()), dilation_h(dilation.row()), dilation_w(dilation.column()), mode(mode), split_k_slices(split_k_slices), groups(groups) { // set output P and Q - P = ((H + pad_h * 2 - R * dilation_h) / stride_h) + 1; - Q = ((W + pad_w * 2 - S * dilation_w) / stride_w) + 1; + P = ((H + pad_h + padding[1] - R * dilation_h) / stride_h) + 1; + Q = ((W + pad_w + padding[3] - S * dilation_w) / stride_w) + 1; } /// Constructs convolution problem size from cutlass Tensor4DCoord and MatrixCoord @@ -195,9 +188,9 @@ struct Conv2dProblemSize { int groups = 1 ): N(input_size.n()), H(input_size.h()), W(input_size.w()), C(input_size.c()), + P(output_size.h()), Q(output_size.w()), K(filter_size.n()), R(filter_size.h()), S(filter_size.w()), - P(output_size.h()), Q(output_size.w()), - pad_h(R / 2), pad_w(S / 2), stride_h(1), stride_w(1), + pad_h(R / 2), pad_w(S / 2), stride_h(1), stride_w(1), dilation_h(1), dilation_w(1), mode(mode), split_k_slices(split_k_slices), groups(groups) {} @@ -221,12 +214,12 @@ struct Conv2dProblemSize { CUTLASS_HOST_DEVICE bool operator==(Conv2dProblemSize const &conv) const { return ( - (N == conv.N) && (W == conv.H) && (W == conv.W) && (C == conv.C) && + (N == conv.N) && (H == conv.H) && (W == conv.W) && (C == conv.C) && (K == conv.K) && (R == conv.R) && (S == conv.S) && (P == conv.P) && (Q == conv.Q) && (pad_h == conv.pad_h) && (pad_w == conv.pad_w) && (stride_h == conv.stride_h) && (stride_w == conv.stride_w) && - (dilation_h == conv.dilation_h) && (dilation_h == conv.dilation_h) + (dilation_h == conv.dilation_h) && (dilation_w == conv.dilation_w) ); } @@ -245,9 +238,10 @@ struct Conv2dProblemSize { /// Returns filter extent as Tensor4DCoord CUTLASS_HOST_DEVICE - cutlass::Tensor4DCoord filter_extent() const { + cutlass::Tensor4DCoord filter_extent(bool is_deconv = false) const { - return cutlass::Tensor4DCoord ({K, R, S, C}); + return is_deconv ? 
cutlass::Tensor4DCoord ({C, R, S, K / groups}) + : cutlass::Tensor4DCoord ({K, R, S, C / groups}); } /// Returns output extent as Tensor4DCoord @@ -268,7 +262,7 @@ struct Conv2dProblemSize { CUTLASS_HOST_DEVICE int64_t filter_size() const { - return (K * R * S * C); + return (K * R * S * C / groups); } /// Returns output size in number of elements @@ -278,7 +272,7 @@ struct Conv2dProblemSize { return (N * P * Q * K); } - /// Returns output extent as Tensor4DCoord + /// Returns padding as Tensor4DCoord CUTLASS_HOST_DEVICE cutlass::Tensor4DCoord padding() const { @@ -336,8 +330,9 @@ cutlass::gemm::GemmCoord implicit_gemm_problem_size( return gemm::GemmCoord( problem_size.N * problem_size.P * problem_size.Q, problem_size.K, - problem_size.R * problem_size.S * problem_size.C + problem_size.R * problem_size.S * problem_size.C / problem_size.groups ); + case Operator::kDeconv: case Operator::kDgrad: return gemm::GemmCoord( problem_size.N * problem_size.H * problem_size.W, @@ -362,61 +357,160 @@ int implicit_gemm_k_iterations( Operator conv_operator, int threadblock_K, Conv2dProblemSize const &problem_size, - IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic) { + IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic, + GroupMode group_mode = GroupMode::kNone, + int threadblock_N = 0) { int iterations = 0; - if (algorithm == IteratorAlgorithm::kFixedChannels) { + if (group_mode == GroupMode::kNone) { - int positions_per_iteration = threadblock_K / problem_size.C; - switch (conv_operator) { - case Operator::kFprop: - iterations = (problem_size.R * problem_size.S + positions_per_iteration - 1 ) / positions_per_iteration; - break; + if (algorithm == IteratorAlgorithm::kFixedChannels) { - default: - break; + int positions_per_iteration = threadblock_K / problem_size.C; + switch (conv_operator) { + case Operator::kFprop: + iterations = (problem_size.R * problem_size.S + positions_per_iteration - 1 ) / positions_per_iteration; + break; + + default: + break; + } } - } - else if (algorithm == IteratorAlgorithm::kFewChannels) { + else if (algorithm == IteratorAlgorithm::kFewChannels) { - switch (conv_operator) { - case Operator::kFprop: - iterations = (problem_size.R * problem_size.S * problem_size.C + threadblock_K - 1 ) / threadblock_K; - break; + switch (conv_operator) { + case Operator::kFprop: + iterations = (problem_size.R * problem_size.S * problem_size.C + threadblock_K - 1 ) / threadblock_K; + break; - default: - break; + default: + break; + } + } + else { + int elements_per_split_k_slice = 0; + + switch (conv_operator) { + case Operator::kFprop: + elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices; + iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); + break; + + case Operator::kDeconv: + case Operator::kDgrad: + elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices; + iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); + break; + + case Operator::kWgrad: + elements_per_split_k_slice = (problem_size.N * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices; + iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K; + break; + + default: + break; + } } - } - else { - int elements_per_split_k_slice = 0; - switch (conv_operator) { - case Operator::kFprop: - elements_per_split_k_slice 
= (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices; - iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); - break; - - case Operator::kDgrad: - elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices; - iterations = problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); - break; - - case Operator::kWgrad: - elements_per_split_k_slice = (problem_size.N * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices; - iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K; - break; - - default: - break; + } else if (group_mode == GroupMode::kDepthwise) { + int channels_per_cta = threadblock_N; + + if (algorithm == IteratorAlgorithm::kAnalytic) { + switch (conv_operator) { + case Operator::kFprop: + iterations = problem_size.R * problem_size.S * + ((channels_per_cta + threadblock_K - 1) / threadblock_K); + break; + + default: + break; + } + } + } else { // Group conv + + int channels_per_group = problem_size.C / problem_size.groups; + int k_per_group = problem_size.K / problem_size.groups; + + if (algorithm == IteratorAlgorithm::kAnalytic) { + switch (conv_operator) { + case Operator::kFprop: + iterations = problem_size.R * problem_size.S * ((channels_per_group + threadblock_K - 1) / threadblock_K); + // In group conv, if k_per_group < threadblock_N, one Threadblock will calculate multiple groups + if (problem_size.groups != 1) { + if (k_per_group < threadblock_N) { + iterations *= threadblock_N / k_per_group; + } + } + break; + + default: + break; + } + } else if (algorithm == IteratorAlgorithm::kOptimized) { + // Current optimized iterator only support GroupMode::kSingleGroup + if (group_mode == GroupMode::kSingleGroup) { + switch (conv_operator) { + case Operator::kFprop: + iterations = problem_size.R * problem_size.S * ((channels_per_group + threadblock_K - 1) / threadblock_K); + break; + + default: + break; + } + } } + } return iterations; } +template +CUTLASS_HOST_DEVICE +int depthwise_gemm_k_iterations( + Operator conv_operator, + int threadblock_K, + Conv2dProblemSize const &problem_size, + IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic, + GroupMode group_mode = GroupMode::kNone, + int threadblock_N = 0) { + + int n = problem_size.N; + int p = (problem_size.P + Output_P - 1) / Output_P; + int q = (problem_size.Q + Output_Q - 1) / Output_Q; + + int iterations = (n * p * q + problem_size.split_k_slices - 1) / problem_size.split_k_slices; + return iterations; +} + + +CUTLASS_HOST_DEVICE +int implicit_gemm_k_iterations_per_channel( + Operator conv_operator, + Conv2dProblemSize const &problem_size, + IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic) { + + int iterations = 0; //0 means not applicable + if (algorithm == IteratorAlgorithm::kAnalytic || algorithm == IteratorAlgorithm::kOptimized) { + switch (conv_operator) { + case Operator::kFprop: + iterations = problem_size.R * problem_size.S; + break; + + case Operator::kDeconv: + case Operator::kDgrad: + iterations = problem_size.R * problem_size.S; + break; + + default: + break; + } + } + return iterations; +} + //////////////////////////////////////////////////////////////////////////////// // Mapping function (ImplicitGemm A, B, C -> Conv Activation, Filter, Output) //////////////////////////////////////////////////////////////////////////////// @@ -427,6 +521,7 
@@ cutlass::Tensor4DCoord implicit_gemm_tensor_a_extent( Conv2dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.activation_extent(); + case cutlass::conv::Operator::kDeconv: case cutlass::conv::Operator::kDgrad: return problem_size.output_extent(); case cutlass::conv::Operator::kWgrad: return problem_size.output_extent(); default : break; @@ -441,6 +536,7 @@ cutlass::Tensor4DCoord implicit_gemm_tensor_b_extent( Conv2dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.filter_extent(); + case cutlass::conv::Operator::kDeconv: return problem_size.filter_extent(true); case cutlass::conv::Operator::kDgrad: return problem_size.filter_extent(); case cutlass::conv::Operator::kWgrad: return problem_size.activation_extent(); default : break; @@ -455,6 +551,7 @@ cutlass::Tensor4DCoord implicit_gemm_tensor_c_extent( Conv2dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.output_extent(); + case cutlass::conv::Operator::kDeconv: case cutlass::conv::Operator::kDgrad: return problem_size.activation_extent(); case cutlass::conv::Operator::kWgrad: return problem_size.filter_extent(); default : break; @@ -469,6 +566,7 @@ int64_t implicit_gemm_tensor_a_size( Conv2dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.activation_size(); + case cutlass::conv::Operator::kDeconv: case cutlass::conv::Operator::kDgrad: return problem_size.output_size(); case cutlass::conv::Operator::kWgrad: return problem_size.output_size(); default : break; @@ -483,6 +581,7 @@ int64_t implicit_gemm_tensor_b_size( Conv2dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.filter_size(); + case cutlass::conv::Operator::kDeconv: case cutlass::conv::Operator::kDgrad: return problem_size.filter_size(); case cutlass::conv::Operator::kWgrad: return problem_size.activation_size(); default : break; @@ -497,6 +596,7 @@ int64_t implicit_gemm_tensor_c_size( Conv2dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.output_size(); + case cutlass::conv::Operator::kDeconv: case cutlass::conv::Operator::kDgrad: return problem_size.activation_size(); case cutlass::conv::Operator::kWgrad: return problem_size.filter_size(); default : break; @@ -537,12 +637,12 @@ void strided_dgrad_starting_coords( // function locals for remainder by fast divmod int pad_h_rem_, pad_w_rem_; - // start_h = platform::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h; + // start_h = std::abs(problem_size.stride_h - ((problem_size.pad_h % problem_size.stride_h) - r)) % problem_size.stride_h; stride_h_divmod.divmod(pad_h_rem_, problem_size.pad_h); int r_ = absolute_value(problem_size.stride_h - (pad_h_rem_ - r)); stride_h_divmod.divmod(start_h, r_); - //start_w = platform::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w; + //start_w = std::abs(problem_size.stride_w - ((problem_size.pad_w % problem_size.stride_w) - s)) % problem_size.stride_w; stride_w_divmod.divmod(pad_w_rem_, problem_size.pad_w); int s_ = absolute_value(problem_size.stride_w - (pad_w_rem_ - s)); stride_w_divmod.divmod(start_w, s_); diff --git a/include/cutlass/conv/conv3d_problem_size.h 
b/include/cutlass/conv/conv3d_problem_size.h index 82ea1cef46..9a9514f2d8 100644 --- a/include/cutlass/conv/conv3d_problem_size.h +++ b/include/cutlass/conv/conv3d_problem_size.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -80,11 +80,11 @@ struct Conv3dProblemSize : public Conv2dProblemSize { public: CUTLASS_HOST_DEVICE Conv3dProblemSize(): + Conv2dProblemSize(), D(0), T(0), Z(0), - pad_d(0), + pad_d(0), stride_d(1), - dilation_d(1), - Conv2dProblemSize() { } + dilation_d(1) { } /// Constructor for default padding, stride, dilation, and split-K CUTLASS_HOST_DEVICE @@ -102,10 +102,10 @@ struct Conv3dProblemSize : public Conv2dProblemSize { int R, int S, Mode mode - ): + ): + Conv2dProblemSize(N, H, W, C, P, Q, K, R, S, mode), D(D), T(T), Z(Z), - pad_d(T / 2), stride_d(1), dilation_d(1), - Conv2dProblemSize(N, H, W, C, P, Q, K, R, S, mode) { } + pad_d(T / 2), stride_d(1), dilation_d(1) { } /// Constructor CUTLASS_HOST_DEVICE @@ -134,15 +134,15 @@ struct Conv3dProblemSize : public Conv2dProblemSize { Mode mode, int split_k_slices = 1, int groups = 1 - ): - D(D), T(T), Z(Z), - pad_d(pad_d), stride_d(stride_d), dilation_d(dilation_d), + ): Conv2dProblemSize( - N, H, W, C, K, R, S, P, Q, - pad_h, pad_w, - stride_h, stride_w, - dilation_h, dilation_w, - mode, split_k_slices, groups) { } + N, H, W, C, K, R, S, P, Q, + pad_h, pad_w, + stride_h, stride_w, + dilation_h, dilation_w, + mode, split_k_slices, groups), + D(D), T(T), Z(Z), + pad_d(pad_d), stride_d(stride_d), dilation_d(dilation_d) { } /// Constructs convolution problem size from cutlass Tensor5DCoord and Coord3D // set *user-defined* output size and sets Z, P, and Q (include all data members in ctor) @@ -158,8 +158,6 @@ struct Conv3dProblemSize : public Conv2dProblemSize { int split_k_slices = 1, int groups = 1 ): - D(input_size.d()), T(filter_size.d()), Z(output_size.d()), - pad_d(padding[0]), stride_d(stride[0]), dilation_d(dilation[0]), Conv2dProblemSize( {input_size.n(), input_size.h(), input_size.w(), input_size.c()}, {filter_size.n(), filter_size.h(), filter_size.w(), filter_size.c()}, @@ -167,8 +165,9 @@ struct Conv3dProblemSize : public Conv2dProblemSize { {stride[1], stride[2]}, {dilation[1], dilation[2]}, {output_size.n(), output_size.h(), output_size.w(), output_size.c()}, - mode, split_k_slices, groups - ) { } + mode, split_k_slices, groups), + D(input_size.d()), T(filter_size.d()), Z(output_size.d()), + pad_d(padding[0]), stride_d(stride[0]), dilation_d(dilation[0]) { } /// Constructs convolution problem size from cutlass Tensor5DCoord and Coord3D // *computes* output size and sets Z, P and Q (include all data members in ctor) @@ -183,18 +182,46 @@ struct Conv3dProblemSize : public Conv2dProblemSize { int split_k_slices = 1, int groups = 1 ): - D(input_size.d()), T(filter_size.d()), - pad_d(padding[0]), stride_d(stride[0]), dilation_d(dilation[0]), Conv2dProblemSize( {input_size.n(), input_size.h(), input_size.w(), input_size.c()}, {filter_size.n(), filter_size.h(), filter_size.w(), filter_size.c()}, {padding[1], padding[1], padding[2], padding[2]}, {stride[1], stride[2]}, {dilation[1], dilation[2]}, - mode, split_k_slices, groups - ) { + mode, split_k_slices, groups), + 
D(input_size.d()), T(filter_size.d()), + pad_d(padding[0]), stride_d(stride[0]), dilation_d(dilation[0]) + { + // set output Z + Z = ((D + pad_d * 2 - T * dilation_d) / stride_d) + 1; + } + + /// Constructs convolution problem size from cutlass Tensor5DCoord, Coord3D + // *computes* output size and sets Z, P and Q (include all data members in ctor) + CUTLASS_HOST_DEVICE + Conv3dProblemSize( + cutlass::Tensor5DCoord input_size, // NDHWC + cutlass::Tensor5DCoord filter_size, // KTRSC + CUTLASS_STL_NAMESPACE::tuple padding, // Coord3D {pad_d, pad_h, pad_w} & Coord3D {far pad_d, pad_h, pad_w} to calculate o/p/q + Coord3D stride, // stride_d, stride_h, stride_w + Coord3D dilation, // dilation_d, dilation_h, dilation_w + cutlass::conv::Mode mode = cutlass::conv::Mode::kCrossCorrelation, + int split_k_slices = 1, + int groups = 1 + ): + Conv2dProblemSize( + {input_size.n(), input_size.h(), input_size.w(), input_size.c()}, + {filter_size.n(), filter_size.h(), filter_size.w(), filter_size.c()}, + {CUTLASS_STL_NAMESPACE::get<0>(padding)[1], CUTLASS_STL_NAMESPACE::get<1>(padding)[1], + CUTLASS_STL_NAMESPACE::get<0>(padding)[2], CUTLASS_STL_NAMESPACE::get<1>(padding)[2]}, + {stride[1], stride[2]}, + {dilation[1], dilation[2]}, + mode, split_k_slices, groups), + D(input_size.d()), T(filter_size.d()), + pad_d(CUTLASS_STL_NAMESPACE::get<0>(padding)[0]), stride_d(stride[0]), dilation_d(dilation[0]) + { // set output Z - Z = ((D + pad_d * 2 - T * dilation_d) / stride_d) + 1; + Z = ((D + pad_d + CUTLASS_STL_NAMESPACE::get<1>(padding)[0] - T * dilation_d) / stride_d) + 1; } /// Equality operator (ignores mode and split_k_slice) @@ -205,8 +232,8 @@ struct Conv3dProblemSize : public Conv2dProblemSize { (K == conv.K) && (T == conv.T) && (R == conv.R) && (S == conv.S) && (Z == conv.Z) &&(P == conv.P) && (Q == conv.Q) && (pad_d == conv.pad_d) && (pad_h == conv.pad_h) && (pad_w == conv.pad_w) && - (stride_d == conv.stride_d) && (stride_h == conv.stride_h) && (stride_w == conv.stride_h) && - (dilation_d == conv.dilation_d) && (dilation_h == conv.dilation_h) && (dilation_h == conv.dilation_h) + (stride_d == conv.stride_d) && (stride_h == conv.stride_h) && (stride_w == conv.stride_w) && + (dilation_d == conv.dilation_d) && (dilation_h == conv.dilation_h) && (dilation_w == conv.dilation_w) ); } @@ -241,9 +268,10 @@ struct Conv3dProblemSize : public Conv2dProblemSize { /// Returns filter extent as Tensor5DCoord CUTLASS_HOST_DEVICE - cutlass::Tensor5DCoord filter_extent() const { + cutlass::Tensor5DCoord filter_extent(bool is_deconv = false) const { - return cutlass::Tensor5DCoord ({K, T, R, S, C}); + return is_deconv ? 
cutlass::Tensor5DCoord ({C, T, R, S, K}) + : cutlass::Tensor5DCoord ({K, T, R, S, C}); } /// Returns output extent as Tensor5DCoord @@ -274,7 +302,7 @@ struct Conv3dProblemSize : public Conv2dProblemSize { return (N * Z * P * Q * K); } - /// Returns output extent as Tensor5DCoord + /// Returns padding as Coord3D CUTLASS_HOST_DEVICE Coord3D padding() const { @@ -315,6 +343,7 @@ cutlass::gemm::GemmCoord implicit_gemm_problem_size( problem_size.K, problem_size.T * problem_size.R * problem_size.S * problem_size.C ); + case Operator::kDeconv: case Operator::kDgrad: return gemm::GemmCoord( problem_size.N * problem_size.D * problem_size.H * problem_size.W, @@ -339,29 +368,47 @@ int implicit_gemm_k_iterations( Operator conv_operator, int threadblock_K, Conv3dProblemSize const &problem_size, - IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic) { + IteratorAlgorithm algorithm = IteratorAlgorithm::kAnalytic, + GroupMode group_mode = GroupMode::kNone, + int threadblock_N = 0) { int iterations = 0; int elements_per_split_k_slice = 0; - - switch (conv_operator) { - case Operator::kFprop: - elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices; - iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); - break; - - case Operator::kDgrad: - elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices; - iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); - break; - - case Operator::kWgrad: - elements_per_split_k_slice = (problem_size.N * problem_size.Z * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices; - iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K; - break; - - default: - break; + if (group_mode == GroupMode::kNone) { + switch (conv_operator) { + case Operator::kFprop: + elements_per_split_k_slice = (problem_size.C + problem_size.split_k_slices - 1) / problem_size.split_k_slices; + iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); + break; + + case Operator::kDeconv: + case Operator::kDgrad: + elements_per_split_k_slice = (problem_size.K + problem_size.split_k_slices - 1) / problem_size.split_k_slices; + iterations = problem_size.T * problem_size.R * problem_size.S * ((elements_per_split_k_slice + threadblock_K - 1) / threadblock_K); + break; + + case Operator::kWgrad: + elements_per_split_k_slice = (problem_size.N * problem_size.Z * problem_size.P * problem_size.Q + problem_size.split_k_slices - 1) / problem_size.split_k_slices; + iterations = (elements_per_split_k_slice + threadblock_K - 1) / threadblock_K; + break; + + default: + break; + } + } else if (group_mode == GroupMode::kDepthwise) { + int channels_per_cta = threadblock_N; + + if (algorithm == IteratorAlgorithm::kAnalytic) { + switch (conv_operator) { + case Operator::kFprop: + iterations = problem_size.T * problem_size.R * problem_size.S * + ((channels_per_cta + threadblock_K - 1) / threadblock_K); + break; + + default: + break; + } + } } return iterations; @@ -377,6 +424,7 @@ cutlass::Tensor5DCoord implicit_gemm_tensor_a_extent( Conv3dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.activation_extent(); + case cutlass::conv::Operator::kDeconv: case 
cutlass::conv::Operator::kDgrad: return problem_size.output_extent(); case cutlass::conv::Operator::kWgrad: return problem_size.output_extent(); default : break; @@ -391,6 +439,7 @@ cutlass::Tensor5DCoord implicit_gemm_tensor_b_extent( Conv3dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.filter_extent(); + case cutlass::conv::Operator::kDeconv: return problem_size.filter_extent(true); case cutlass::conv::Operator::kDgrad: return problem_size.filter_extent(); case cutlass::conv::Operator::kWgrad: return problem_size.activation_extent(); default : break; @@ -405,6 +454,7 @@ cutlass::Tensor5DCoord implicit_gemm_tensor_c_extent( Conv3dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.output_extent(); + case cutlass::conv::Operator::kDeconv: case cutlass::conv::Operator::kDgrad: return problem_size.activation_extent(); case cutlass::conv::Operator::kWgrad: return problem_size.filter_extent(); default : break; @@ -419,6 +469,7 @@ int64_t implicit_gemm_tensor_a_size( Conv3dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.activation_size(); + case cutlass::conv::Operator::kDeconv: case cutlass::conv::Operator::kDgrad: return problem_size.output_size(); case cutlass::conv::Operator::kWgrad: return problem_size.output_size(); default : break; @@ -433,6 +484,7 @@ int64_t implicit_gemm_tensor_b_size( Conv3dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.filter_size(); + case cutlass::conv::Operator::kDeconv: case cutlass::conv::Operator::kDgrad: return problem_size.filter_size(); case cutlass::conv::Operator::kWgrad: return problem_size.activation_size(); default : break; @@ -447,6 +499,7 @@ int64_t implicit_gemm_tensor_c_size( Conv3dProblemSize const &problem_size) { switch (conv_operator) { case cutlass::conv::Operator::kFprop: return problem_size.output_size(); + case cutlass::conv::Operator::kDeconv: case cutlass::conv::Operator::kDgrad: return problem_size.activation_size(); case cutlass::conv::Operator::kWgrad: return problem_size.filter_size(); default : break; diff --git a/include/cutlass/conv/convnd_problem_shape.hpp b/include/cutlass/conv/convnd_problem_shape.hpp new file mode 100644 index 0000000000..cd2f674ff4 --- /dev/null +++ b/include/cutlass/conv/convnd_problem_shape.hpp @@ -0,0 +1,601 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief This file contains definitions and utility functions for describing convolution problem shapes. +*/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_coord.h" +#include "cutlass/conv/convolution.h" + +#include "cute/container/array.hpp" + +#if ! defined(__CUDACC_RTC__) +#include +#endif + + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Implements the user facing argument for all CUTLASS 3.x convolutions in a rank agnostic fashion. +// All tensors are flat and by default treated as layout right (NDHWC, KTRSC, NZPQK) +// Supports asymmetric padding, traversal strides, dilations, and all conv algorithm types. +template < + conv::Operator ConvOp_, + int NumSpatialDimensions_ +> +struct ConvProblemShape { + // + // Alias types for members + // + + static constexpr int RankS = NumSpatialDimensions_; + static constexpr int RankT = NumSpatialDimensions_ + 2; + static constexpr conv::Operator ConvOp = ConvOp_; + static constexpr int NumSpatialDimensions = NumSpatialDimensions_; + using SpatialExtent = cute::array; + using TensorExtent = cute::array; + using TensorStride = cute::array; + using ShapePadding = SpatialExtent; + using TraversalStride = SpatialExtent; + using ShapeDilation = SpatialExtent; + using Corner = SpatialExtent; + + // + // Members + // + cutlass::conv::Mode mode{}; + TensorExtent shape_A{}; + TensorStride stride_A{}; + TensorExtent shape_B{}; + TensorStride stride_B{}; + TensorExtent shape_C{}; + TensorStride stride_C{}; + + // asymmetric padding, both upper and lower padding must be >= 0 + ShapePadding lower_padding{}; + ShapePadding upper_padding{}; + TraversalStride traversal_stride{}; + ShapeDilation dilation{}; + int groups = 1; + + // + // Methods + // + + ConvProblemShape() = default; + + // Constructor accepts user facing arguments and computes to stores the corners as its internal state + ConvProblemShape( + conv::Mode mode, // convolution/cross-correlation + TensorExtent shape_act, // [n,d,h,w,c] + TensorStride stride_act, // [n,d,h,w,c] + TensorExtent shape_flt, // [k,t,r,s,c] + TensorStride stride_flt, // [k,t,r,s,c] + ShapePadding lower_padding, // [pad_d, pad_h, pad_w] + ShapePadding upper_padding, // [pad_d, pad_h, pad_w] + TraversalStride tstride, // [stride_d, stride_h, stride_w] + ShapeDilation dilation, // [dilation_d, dilation_h, dilation_w] + int groups) + : mode(mode) + , lower_padding(lower_padding) + , upper_padding(upper_padding) + 
, traversal_stride(tstride) + , dilation(dilation) + , groups(groups) { + + auto [shape_xformed_act, stride_xformed_act] = calculate_xformed_act(shape_act, shape_flt); + set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act); + } + + // Allow user input of xformed activation stride to support non-packed strides. + ConvProblemShape( + conv::Mode mode, // convolution/cross-correlation + TensorExtent shape_act, // [n,d,h,w,c] + TensorStride stride_act, // [n,d,h,w,c] + TensorExtent shape_flt, // [k,t,r,s,c] + TensorStride stride_flt, // [k,t,r,s,c] + TensorStride stride_xformed_act, // [n,z,p,q,k] + ShapePadding lower_padding, // [pad_d, pad_h, pad_w] + ShapePadding upper_padding, // [pad_d, pad_h, pad_w] + TraversalStride tstride, // [stride_d, stride_h, stride_w] + ShapeDilation dilation, // [dilation_d, dilation_h, dilation_w] + int groups) + : mode(mode) + , lower_padding(lower_padding) + , upper_padding(upper_padding) + , traversal_stride(tstride) + , dilation(dilation) + , groups(groups) { + + CUTLASS_ASSERT(stride_act[RankT - 1] == 1); + CUTLASS_ASSERT(stride_flt[RankT - 1] == 1); + CUTLASS_ASSERT(stride_xformed_act[RankT - 1] == 1); + + auto stride_act_packed = packed_stride_right_major(shape_act); + auto stride_flt_packed = packed_stride_right_major(shape_flt); + auto [shape_xformed_act, stride_xformed_act_packed] = calculate_xformed_act(shape_act, shape_flt); + + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < RankT - 1; ++i) { + CUTLASS_ASSERT(stride_act[i] >= stride_act_packed[i]); + CUTLASS_ASSERT(stride_flt[i] >= stride_flt_packed[i]); + CUTLASS_ASSERT(stride_xformed_act[i] >= stride_xformed_act_packed[i]); + } + + set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act); + } + + // Constructor accepts user facing arguments and presume packed tensor strides in canonical (CWHDN) order. + ConvProblemShape( + conv::Mode mode, + TensorExtent shape_act, + TensorExtent shape_flt, + ShapePadding lower_padding, + ShapePadding upper_padding, + TraversalStride tstride, + ShapeDilation dilation, + int groups) + : ConvProblemShape( + mode, + shape_act, + packed_stride_right_major(shape_act), + shape_flt, + packed_stride_right_major(shape_flt), + lower_padding, + upper_padding, + tstride, + dilation, + groups) { + } + +#if ! 
defined(__CUDACC_RTC__) + // Constructor accepts user facing arguments and computes to stores the corners as its internal state + ConvProblemShape( + conv::Mode mode, + std::initializer_list shape_act_, + std::initializer_list stride_act_, + std::initializer_list shape_flt_, + std::initializer_list stride_flt_, + std::initializer_list lower_padding_, + std::initializer_list upper_padding_, + std::initializer_list traversal_stride_, + std::initializer_list dilation_, + int groups) + : mode(mode) + , groups(groups) { + + TensorExtent shape_act{}; + TensorStride stride_act{}; + TensorExtent shape_flt{}; + TensorStride stride_flt{}; + + assert(shape_act_.size() == shape_act.size()); + assert(stride_act_.size() == stride_act.size()); + assert(shape_flt_.size() == shape_flt.size()); + assert(stride_flt_.size() == stride_flt.size()); + assert(lower_padding_.size() == lower_padding.size()); + assert(upper_padding_.size() == upper_padding.size()); + assert(traversal_stride_.size() == traversal_stride.size()); + assert(dilation_.size() == dilation.size()); + + std::copy(shape_act_.begin(), shape_act_.end(), shape_act.begin()); + std::copy(stride_act_.begin(), stride_act_.end(), stride_act.begin()); + std::copy(shape_flt_.begin(), shape_flt_.end(), shape_flt.begin()); + std::copy(stride_flt_.begin(), stride_flt_.end(), stride_flt.begin()); + std::copy(lower_padding_.begin(), lower_padding_.end(), lower_padding.begin()); + std::copy(upper_padding_.begin(), upper_padding_.end(), upper_padding.begin()); + std::copy(traversal_stride_.begin(), traversal_stride_.end(), traversal_stride.begin()); + std::copy(dilation_.begin(), dilation_.end(), dilation.begin()); + + auto [shape_xformed_act, stride_xformed_act] = calculate_xformed_act(shape_act, shape_flt); + set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act); + } + + // Allow user input of xformed activation stride to support non-packed strides. 
+ ConvProblemShape( + conv::Mode mode, + std::initializer_list shape_act_, + std::initializer_list stride_act_, + std::initializer_list shape_flt_, + std::initializer_list stride_flt_, + std::initializer_list stride_xformed_act_, + std::initializer_list lower_padding_, + std::initializer_list upper_padding_, + std::initializer_list traversal_stride_, + std::initializer_list dilation_, + int groups) + : mode(mode) + , groups(groups) { + TensorExtent shape_act{}; + TensorStride stride_act{}; + TensorExtent shape_flt{}; + TensorStride stride_flt{}; + TensorStride stride_xformed_act{}; + + std::copy(shape_act_.begin(), shape_act_.end(), shape_act.begin()); + std::copy(stride_act_.begin(), stride_act_.end(), stride_act.begin()); + std::copy(shape_flt_.begin(), shape_flt_.end(), shape_flt.begin()); + std::copy(stride_flt_.begin(), stride_flt_.end(), stride_flt.begin()); + std::copy(stride_xformed_act_.begin(), stride_xformed_act_.end(), stride_xformed_act.begin()); + std::copy(lower_padding_.begin(), lower_padding_.end(), lower_padding.begin()); + std::copy(upper_padding_.begin(), upper_padding_.end(), upper_padding.begin()); + std::copy(traversal_stride_.begin(), traversal_stride_.end(), traversal_stride.begin()); + std::copy(dilation_.begin(), dilation_.end(), dilation.begin()); + + CUTLASS_ASSERT(stride_act[RankT - 1] == 1); + CUTLASS_ASSERT(stride_flt[RankT - 1] == 1); + CUTLASS_ASSERT(stride_xformed_act[RankT - 1] == 1); + + auto stride_act_packed = packed_stride_right_major(shape_act); + auto stride_flt_packed = packed_stride_right_major(shape_flt); + auto [shape_xformed_act, stride_xformed_act_packed] = calculate_xformed_act(shape_act, shape_flt); + + CUTLASS_PRAGMA_UNROLL + for(int i = 0; i < RankT - 1; ++i) { + CUTLASS_ASSERT(stride_act[i] >= stride_act_packed[i]); + CUTLASS_ASSERT(stride_flt[i] >= stride_flt_packed[i]); + CUTLASS_ASSERT(stride_xformed_act[i] >= stride_xformed_act_packed[i]); + } + + set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act); + } + + // Constructor accepts user facing arguments and computes to stores the corners as its internal state + ConvProblemShape( + conv::Mode mode, + std::initializer_list shape_act_, + std::initializer_list shape_flt_, + std::initializer_list lower_padding_, + std::initializer_list upper_padding_, + std::initializer_list traversal_stride_, + std::initializer_list dilation_, + int groups) + : mode(mode) + , groups(groups) { + TensorExtent shape_act{}; + TensorStride stride_act{}; + TensorExtent shape_flt{}; + TensorStride stride_flt{}; + + assert(shape_act_.size() == shape_act.size()); + assert(shape_flt_.size() == shape_flt.size()); + assert(lower_padding_.size() == lower_padding.size()); + assert(upper_padding_.size() == upper_padding.size()); + assert(traversal_stride_.size() == traversal_stride.size()); + assert(dilation_.size() == dilation.size()); + + std::copy(shape_act_.begin(), shape_act_.end(), shape_act.begin()); + std::copy(shape_flt_.begin(), shape_flt_.end(), shape_flt.begin()); + std::copy(lower_padding_.begin(), lower_padding_.end(), lower_padding.begin()); + std::copy(upper_padding_.begin(), upper_padding_.end(), upper_padding.begin()); + std::copy(traversal_stride_.begin(), traversal_stride_.end(), traversal_stride.begin()); + std::copy(dilation_.begin(), dilation_.end(), dilation.begin()); + stride_act = packed_stride_right_major(shape_act); + stride_flt = packed_stride_right_major(shape_flt); + + auto [shape_xformed_act, stride_xformed_act] = 
calculate_xformed_act(shape_act, shape_flt); + set_shape_stride_ABC(shape_act, stride_act, shape_flt, stride_flt, shape_xformed_act, stride_xformed_act); + } +#endif // not defined(__CUDACC_RTC__) + + // Set shape and stride of tensor A/B/C according to following table: + // | | Fprop | Dgrad | Wgrad | + // | ------ | ------ | ------ | ------| + // | ShapeA | NDHWC | NZPQK | NZPQK | + // | ShapeB | KTRSC | KTRSC | NDHWC | + // | ShapeC | NZPQK | NDHWC | KTRSC | + // + // Input comes from calculate_xformed_act, which does NOT depend on ConvOp. + CUTLASS_HOST_DEVICE + constexpr void + set_shape_stride_ABC( + TensorExtent shape_act, + TensorStride stride_act, + TensorExtent shape_flt, + TensorStride stride_flt, + TensorExtent shape_xformed_act, + TensorStride stride_xformed_act) { +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("*** set_shape_stride_ABC ***"); + printf("\n shape_act: "); + print(shape_act); + printf("\n stride_act: "); + print(stride_act); + printf("\n shape_flt: "); + print(shape_flt); + printf("\n stride_flt: "); + print(stride_flt); + printf("\n shape_xformed_act: "); + print(shape_xformed_act); + printf("\n stride_xformed_act: "); + print(stride_xformed_act); + if constexpr (ConvOp == cutlass::conv::Operator::kFprop) { + printf("\n ConvOp: Fprop"); + } + if constexpr (ConvOp == cutlass::conv::Operator::kDgrad) { + printf("\n ConvOp: Dgrad"); + } + if constexpr (ConvOp == cutlass::conv::Operator::kWgrad) { + printf("\n ConvOp: Wgrad"); + } + printf("\n"); +#endif + + if constexpr (ConvOp == cutlass::conv::Operator::kFprop) { + shape_A = shape_act; + stride_A = stride_act; + shape_B = shape_flt; + stride_B = stride_flt; + shape_C = shape_xformed_act; + stride_C = stride_xformed_act; + } + else if constexpr (ConvOp == cutlass::conv::Operator::kDgrad) { + shape_A = shape_xformed_act; + stride_A = stride_xformed_act; + shape_B = shape_flt; + stride_B = stride_flt; + shape_C = shape_act; + stride_C = stride_act; + } + else if constexpr (ConvOp == cutlass::conv::Operator::kWgrad) { + shape_A = shape_xformed_act; + stride_A = stride_xformed_act; + shape_B = shape_act; + stride_B = stride_act; + shape_C = shape_flt; + stride_C = stride_flt; + } +#if defined(CUTLASS_DEBUG_TRACE_LEVEL) && (CUTLASS_DEBUG_TRACE_LEVEL > 1) + printf("\n shape_A: "); + print(shape_A); + printf("\n stride_A: "); + print(stride_A); + printf("\n shape_B: "); + print(shape_B); + printf("\n stride_B: "); + print(stride_B); + printf("\n shape_C: "); + print(shape_C); + printf("\n stride_C: "); + print(stride_C); +#endif + } + + // Get A extents. + // fprop: A extents array contains [N,D,H,W,C]. Turn that into ((W,H,D,N), (C)) + // dgrad: A extents array contains [N,Z,P,Q,K]. Turn that into ((Q,P,Z,N), (K)) + // wgrad: A extents array contains [N,Z,P,Q,K]. Turn that into ((K), (Q,P,Z,N)) + CUTLASS_HOST_DEVICE + constexpr auto + get_shape_A() const { + using cute::make_shape; + using cute::take; + + if constexpr (ConvOp == conv::Operator::kFprop || + ConvOp == conv::Operator::kDgrad) { + return make_shape( + cute::reverse(take<0, RankT - 1>(shape_A)), + shape_A[RankT - 1]); + } + // For wgrad kernel, we need to linearize NZPQ for tensor A + else if constexpr (ConvOp == conv::Operator::kWgrad) { + return make_shape( + shape_A[RankT - 1], + cute::product(take<0, RankT - 1>(shape_A))); + } + } + + // Get B extents. + // fprop: B extents array contains [K,T,R,S,C]. Turn that into ((K), (C,S,R,T)) + // dgrad: B extents array contains [K,T,R,S,C]. 
Turn that into ((C), (K,S,R,T)) + // wgrad: B extents array contains [N,D,H,W,C]. Turn that into ((C), (W,H,D,N)) + CUTLASS_HOST_DEVICE + constexpr auto + get_shape_B() const { + using cute::make_shape; + using cute::reverse; + using cute::take; + + if constexpr (ConvOp == conv::Operator::kFprop) { + return make_shape( + shape_B[0], + reverse(take<1, RankT>(shape_B))); + } + else if constexpr (ConvOp == conv::Operator::kWgrad) { + return make_shape( + shape_B[RankT - 1], + reverse(take<0, RankT - 1>(shape_B))); + } + else if constexpr (ConvOp == conv::Operator::kDgrad) { + // shape_B: [K,T,R,S,C], return: [(C),(K,S,R,T)] + return make_shape( + shape_B[RankT - 1], + cute::insert<0>( + reverse(take<1, RankT - 1>(shape_B)), + shape_B[0])); + } + } + + // Get C extents. + // fprop: C extents array contains [N,Z,P,Q,K]. Turn that into ((Q,P,Z,N), (K)) + // dgrad: C extents array contains [N,D,H,W,C]. Turn that into ((W,H,D,N), (C)) + // wgrad: C extents array contains [K,T,R,S,C]. Turn that into ((K), (C,S,R,T)) + CUTLASS_HOST_DEVICE + constexpr auto + get_shape_C() const { + using cute::make_shape; + using cute::reverse; + using cute::take; + + if constexpr (ConvOp == conv::Operator::kFprop || + ConvOp == conv::Operator::kDgrad) { + return make_shape( + reverse(take<0, RankT - 1>(shape_C)), + shape_C[RankT - 1]); + } + else if constexpr (ConvOp == conv::Operator::kWgrad) { + return make_shape( + shape_C[0], + reverse(take<1, RankT>(shape_C))); + } + } + + // Static method that returns the canonical strides of tensors (layouts are right major and compact) + CUTLASS_HOST_DEVICE + static constexpr TensorStride + packed_stride_right_major(TensorExtent const& extents) { + TensorStride strides{}; + strides[RankT-1] = 1; + cute::for_each(cute::make_rseq{}, [&](auto i) { + strides[i] = extents[i+1] * strides[i+1]; + }); + return strides; + } + + // Static method that returns the packed logical size of any TensorExtent + CUTLASS_HOST_DEVICE + static constexpr size_t + size(TensorExtent const& extents) { + size_t size = 1; + cute::for_each(cute::make_seq{}, [&](auto i) { + size *= extents[i]; + }); + return size; + } + + CUTLASS_HOST_DEVICE + constexpr size_t + size_A() const { + return shape_A[0] * stride_A[0]; + } + + CUTLASS_HOST_DEVICE + constexpr size_t + size_B() const { + return shape_B[0] * stride_B[0]; + } + + CUTLASS_HOST_DEVICE + constexpr size_t + size_C() const { + return shape_C[0] * stride_C[0]; + } + + // Equality operator + CUTLASS_HOST_DEVICE + bool operator==(ConvProblemShape const& rhs) const { + using cute::for_each; + using cute::make_seq; + + bool is_equal = true; + + // Compare all tensor extents + for_each(make_seq{}, [&](auto i) { + is_equal = is_equal + && (shape_A[i] == rhs.shape_A[i]) + && (shape_B[i] == rhs.shape_B[i]); + }); + + // Compare all spatial extents + for_each(make_seq{}, [&](auto i) { + is_equal = is_equal + && (lower_padding[i] == rhs.lower_padding[i]) + && (upper_padding[i] == rhs.upper_padding[i]) + && (traversal_stride[i] == rhs.traversal_stride[i]) + && (dilation[i] == rhs.dilation[i]); + }); + + return is_equal; + } + + /// Inequality operator + CUTLASS_HOST_DEVICE + bool operator!=(ConvProblemShape const &rhs) const { + return !(*this == rhs); + } + +private: + CUTLASS_HOST_DEVICE + constexpr auto + calculate_xformed_act(TensorExtent shape_act, TensorExtent shape_flt) { + TensorExtent shape_xformed_act{}; + // calculate n,z,p,q,k. 
+ // a helper lambda to compute a single spatial extent of the nzpqk tensor + auto nzpqk_extent = [](int act_ext, int filter_ext, int pad_total, int dilation, int tstride) { + return 1 + (act_ext + pad_total - ((filter_ext -1) * dilation + 1)) / tstride; + }; + + shape_xformed_act[0] = shape_act[0]; // Activation N extent + cute::for_each(cute::make_seq{}, [&](auto i) { + shape_xformed_act[i+1] = nzpqk_extent( + shape_act[i+1], shape_flt[i+1], upper_padding[i] + lower_padding[i], dilation[i], traversal_stride[i]); + }); + shape_xformed_act[RankT-1] = shape_flt[0]; // Filter K extent + + TensorStride stride_xformed_act = packed_stride_right_major(shape_xformed_act); + + return cute::make_tuple(shape_xformed_act, stride_xformed_act); + } +}; + +template< + conv::Operator ConvOp, + int SpatialDim +> +void print(ConvProblemShape const& problem) { + printf("ConvProblemShape with %d spatial dimensions implementing cutlass::conv::Operator::%d\n", + SpatialDim, int(ConvOp)); + printf("\tTensorA: "); + cute::print(problem.shape_A); printf(":"); + cute::print(problem.stride_A); printf("\n"); + printf("\tTensorB: "); + cute::print(problem.shape_B); printf(":"); + cute::print(problem.stride_B); printf("\n"); + printf("\tTensorC: "); + cute::print(problem.shape_C); printf(":"); + cute::print(problem.stride_C); printf("\n"); + printf("\tLower padding: "); print(problem.lower_padding); printf("\n"); + printf("\tUpper padding: "); print(problem.upper_padding); printf("\n"); + printf("\tTraversal strides: "); print(problem.traversal_stride); printf("\n"); + printf("\tDilation: "); print(problem.dilation); printf("\n"); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/convolution.h b/include/cutlass/conv/convolution.h index 52a4636c12..243ee269dd 100644 --- a/include/cutlass/conv/convolution.h +++ b/include/cutlass/conv/convolution.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -29,18 +29,18 @@ * **************************************************************************************************/ /*! \file - \brief + \brief -This file contains definitions and utility functions for describing convolution problem sizes in terms of -activation (NHWC), filter (KRSC), output (NPQK), pading (pad_h, pad_w), stride (stride_h, stride_w), -dilation (dilation_h, dilation_w). Furthermore, it defines helper functions to map cutlass' implicit gemm -tensor extents, sizes, data types to that of convolutions extents, sizes, and data types. +This file contains definitions and utility functions for describing convolution problem sizes in terms of +activation (NHWC), filter (KRSC), output (NPQK), padding (pad_h, pad_w), stride (stride_h, stride_w), and +dilation (dilation_h, dilation_w). Furthermore, it defines helper functions to map CUTLASS's implicit gemm +tensor extents, sizes, and data types to that of the convolution's extents, sizes, and data types. 
* Mapping convolutions to Gemm computation * -Cutlass employs ImplicitGemm algorithm to implement convolutions. ImplicitGemm algorithm runs gemm operation -on convolution tensors Activation, Filter, and Output . The underlying gemm operation follows the standard -gemm definition: +Cutlass implements convolutions with the Implicit Gemm algorithm. This algorithm performs a gemm +(general matrix-matrix multiply) on the convolution tensors Activation, Filter, and Output. +The underlying gemm operation follows the standard gemm definition: C = A * B + C @@ -48,22 +48,23 @@ gemm definition: C is source and output matrix -For the three convolutional operators (Fprop, Dgrad, Wgrad), ImplicitGemm matrices A, B, and C are mapped on -to convolution tensors Activation, Filter and Output as per the below table: +For the three convolutional operators (Fprop, Dgrad, Wgrad), ImplicitGemm matrices A, B, and C are mapped +to convolution tensors Activation, Filter and Output as described in the table below. ___________________________________________________________________________ - ConvolutionalOperator | A | B | C + ConvolutionalOperator | A | B | C ___________________________________________________________________________ | | | | | - | Fprop | Activation | Filter | Output | - | Dgrad | Output | Filter | Activation | - | Wgrad | Output | Activation | Filter | + | Fprop | Activation | Filter | Output | + | Dgrad | Output | Filter | Activation | + | Wgrad | Output | Activation | Filter | ___________________________________________________________________________ -In convolution codebase, DO NOT mix using (A, B, C) with (Acvitation, Filter, Output). +In convolution codebase, DO NOT mix using (A, B, C) with (Activation, Filter, Output). -For example, a convolution class/function with A, B, Output is confusing and error-prone. Instead use below -mapping functions and adhere to using either A, B, C or Acvitation, Filter, Output. +For example, it's confusing and error prone to document a convolution class or function +as operating on "A, B, Output." Instead, use the mapping functions below, +and adhere to using either A, B, C or Activation, Filter, Output. 
Map elements' data types (ImplicitGemm -> Conv): GemmToConvElementMap Map elements' data types (Conv -> ImplicitGemm): ConvToGemmElementMap @@ -72,9 +73,10 @@ Map elements' data types (Conv -> ImplicitGemm): ConvToGemmElementMap #pragma once #include "cutlass/cutlass.h" +#include "cutlass/layout/tensor.h" #include "cutlass/tensor_coord.h" #include "cutlass/fast_math.h" -#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/gemm_enumerated_types.h" #include "cutlass/matrix_coord.h" namespace cutlass { @@ -83,40 +85,107 @@ namespace conv { //////////////////////////////////////////////////////////////////////////////////////////////////// /// Convolutional operator -enum class Operator { - kFprop, - kDgrad, - kWgrad +enum class Operator { + kFprop, + kDgrad, + kWgrad, + kDeconv }; -/// Distinguishes convolution from cross correlation -enum class Mode { - kCrossCorrelation, - kConvolution +/// Distinguishes convolution from cross correlation +enum class Mode { + kCrossCorrelation, + kConvolution }; /// Selects among several implementation variants trading off performance with simplicity -enum class IteratorAlgorithm { +enum class IteratorAlgorithm { kAnalytic, ///< functionally correct in all cases but lower performance kOptimized, ///< optimized for R <= 32, S <= 32 and unity-stride dgrad kFixedChannels, ///< Analytic algorithm optimized for fixed channel count (C == AccessSize) - kFewChannels ///< Analytic algorithm optimized for few channels (C divisible by AccessSize) + kFewChannels, ///< Analytic algorithm optimized for few channels (C divisible by AccessSize) + kFixedStrideDilation ///< Optimized for fixed stride and dilation }; /// Distinguishes among partial specializations that accelerate certain problems where convolution /// stride is unit. 
enum class StrideSupport { kStrided, ///< arbitrary convolution stride - kUnity ///< unit convolution stride + kUnity, ///< unit convolution stride + kFixed ///< fixed convolution stride }; /// Identifies split-K mode -enum class SplitKMode { - kNone, - kSerial, +enum class SplitKMode { + kNone, + kSerial, kParallel }; +/// Identifies group mode +enum class GroupMode { + kNone, + kSingleGroup, ///< One CTA calculates one group or less + kMultipleGroup, ///< One CTA calculates multiple groups + kDepthwise ///< One CTA calculates cta_n groups (problem_size.C == problem_size.K == problem_size.groups) +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Shape of a tensor +template < + int N = 1, + int H = 1, + int W = 1, + int C = 1 +> +struct TensorNHWCShape { + static int const kN = N; + static int const kH = H; + static int const kW = W; + static int const kC = C; + + static int const kHW = H * W; + static int const kNHW = N * kHW; + static int const kNHWC = N * H * W * C; + + static int const kCount = kNHWC; + + // + // Static member functions + // + + /// Returns a Coord object + CUTLASS_HOST_DEVICE + static Coord<4> toCoord() { + return make_Coord(kN, kH, kW, kC); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Shape of a conv2d stride, which controls how the filter convolves around the input volume +template < + /// Stride in horizontal direction + int u = 1, + /// Stride in vertical direction + int v = 1 +> +struct Stride2D { + static int const kU = u; + static int const kV = v; + + // + // Static member functions + // + + /// Returns a Coord object + CUTLASS_HOST_DEVICE + static Coord<2> toCoord() { + return make_Coord(kU, kV); + } +}; + //////////////////////////////////////////////////////////////////////////////////////////////////// } // namespace conv diff --git a/include/cutlass/conv/detail.hpp b/include/cutlass/conv/detail.hpp new file mode 100644 index 0000000000..3e4173569c --- /dev/null +++ b/include/cutlass/conv/detail.hpp @@ -0,0 +1,137 @@ + +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/conv/convnd_problem_shape.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::detail { + +///////////////////////////////////////////////////////////////////////////////////////////////// + + // Helper function to get the problem shape +template +auto get_problem_shape_MNKL_helper(ProblemShape const& problem_shape, cute::true_type) { + return T::get_problem_shape_MNKL(problem_shape); +} + +template +ProblemShape get_problem_shape_MNKL_helper(ProblemShape const& problem_shape, cute::false_type) { + return problem_shape; +} + +// Get problem shape MNKL according to following table: +// | | Fprop | Dgrad | Wgrad | +// | ---- | --------- | -------- | -------- | +// | Shape_M | (Q,P,Z,N) | (W/V,H/U,D/O,N) | (K) | +// | Shape_N | (K) | (C) | (C,S,R,T) | +// | Shape_K | (C,S,R,T) | (K,S,R,T) | (Q,P,Z,N) | +// | Shape_L | _1 | (V,U,O) | _1 | + +template +CUTLASS_HOST_DEVICE +constexpr auto +get_transformed_problem_shape_MNKL(ProblemShape const& problem_shape) { + return problem_shape; +} + + +template +CUTLASS_HOST_DEVICE +constexpr auto +get_transformed_problem_shape_MNKL(ConvProblemShape const& problem_shape) { + using cute::insert; + using cute::make_shape; + using cute::reverse; + using cute::take; + + constexpr int RankT = SpatialDim + 2; + + if constexpr (ConvOp == conv::Operator::kWgrad) { + auto M_xformed = problem_shape.shape_C[0]; + auto N_xformed = reverse(take<1, RankT>(problem_shape.shape_C)); + auto K_xformed = reverse(take<0, RankT - 1>(problem_shape.shape_A)); + auto L_xformed = cute::Int<1>{}; + + return make_shape(M_xformed, N_xformed, K_xformed, L_xformed); + } + else if constexpr (ConvOp == conv::Operator::kFprop){ + auto M_xformed = reverse(take<0, RankT - 1>(problem_shape.shape_C)); + auto N_xformed = problem_shape.shape_C[RankT - 1]; + auto K_xformed = reverse(take<1, RankT>(problem_shape.shape_B)); + auto L_xformed = cute::Int<1>{}; + + return make_shape(M_xformed, N_xformed, K_xformed, L_xformed); + } + else if constexpr (ConvOp == conv::Operator::kDgrad) { + auto L_xformed = reverse(problem_shape.traversal_stride); // (V,U,O) + auto M_xformed = ceil_div(reverse(take<0,RankT - 1>(problem_shape.shape_C)), L_xformed); + auto N_xformed = problem_shape.shape_C[RankT - 1]; + // shape_B: [K,T,R,S,C], K_xformed: [K,S,R,T] + auto K_xformed = insert<0>( + (reverse(take<1,RankT - 1>(problem_shape.shape_B))), + problem_shape.shape_B[0]); + + return make_shape(M_xformed, N_xformed, K_xformed, L_xformed); + } +} + +// Assuming im2col linearization +// Get problem shape MNKL according to following table: +// | | Fprop | Dgrad | Wgrad | +// | ---- | --------- | -------- | -------- | +// | Shape_M | (Q*P*Z*N) | ([W/V]*[H/U]*[D/O]*N) | (K) | +// | Shape_N | (K) | (C) | (C,S,R,T) | +// | Shape_K | (C,S,R,T) | (K,S,R,T) | (Q*P*Z*N) | +// | 
Shape_L | _1 | (V*U*O) | _1 | +template +CUTLASS_HOST_DEVICE +constexpr auto +get_linearized_problem_shape_MNKL(ConvProblemShape const& problem_shape) { + + auto [M, N, K, L] = get_transformed_problem_shape_MNKL(problem_shape); + + if constexpr (ConvOp == conv::Operator::kFprop || ConvOp == conv::Operator::kDgrad) { + return cute::make_shape(cute::product(M), N, K, cute::product(L)); + } + else if constexpr (ConvOp == conv::Operator::kWgrad) { + return cute::make_shape(M, N, cute::product(K), L); + } + +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::detail + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/device/conv_universal_adapter.hpp b/include/cutlass/conv/device/conv_universal_adapter.hpp new file mode 100644 index 0000000000..193f8d8854 --- /dev/null +++ b/include/cutlass/conv/device/conv_universal_adapter.hpp @@ -0,0 +1,421 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +// common +#include "cutlass/arch/mma.h" +#include "cutlass/cutlass.h" +#include "cutlass/arch/mma.h" +#include "cutlass/trace.h" +#include "cutlass/cluster_launch.hpp" +#include "cutlass/device_kernel.h" + +#include "cutlass/conv/kernel/conv_universal.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/detail/layout.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::device { + +//////////////////////////////////////////////////////////////////////////////// + +/*! 
+ ConvUniversalAdapter is a stateful, reusable handle built around a kernel + of type cutlass::conv::kernel::ConvUniversal. + + It manages the lifetime of the underlying `kernel::Params` struct, and exposes APIs + to create it from the host facing arguments. For power users, static methods + are exposed that bypass the stateful methods or args->params lowering. +*/ +template +class ConvUniversalAdapter +{ +public: + using ConvKernel = GetUnderlyingKernel_t; + using TileShape = typename ConvKernel::TileShape; + using ElementA = typename ConvKernel::ElementA; + using ElementB = typename ConvKernel::ElementB; + using ElementC = typename ConvKernel::ElementC; + using ElementD = typename ConvKernel::ElementD; + using ElementAccumulator = typename ConvKernel::TiledMma::ValTypeC; + using DispatchPolicy = typename ConvKernel::DispatchPolicy; + using CollectiveMainloop = typename ConvKernel::CollectiveMainloop; + using CollectiveEpilogue = typename ConvKernel::CollectiveEpilogue; + + static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; + + // Tease out meta-information about the conv algorithm + static constexpr conv::Operator kConvolutionalOperator = DispatchPolicy::ConvOp; + static constexpr int NumSpatialDimensions = CollectiveMainloop::NumSpatialDimensions; + + // If our TiledMMA's instruction thread layout size is larger than 1, we know its a tensorop! + using OperatorClass = cute::conditional_t< + (cute::size(typename ConvKernel::TiledMma::AtomThrID{}) > 1), + cutlass::arch::OpClassTensorOp, cutlass::arch::OpClassSimt>; + + using ArchTag = typename ConvKernel::ArchTag; + + // Assume TiledMma's ShapeMNK is the same as 2.x's ThreadblockShape + using ThreadblockShape = cutlass::gemm::GemmShape< + cute::size<0>(TileShape{}), + cute::size<1>(TileShape{}), + cute::size<2>(TileShape{})>; + + using ClusterShape = cutlass::gemm::GemmShape< + cute::size<0>(typename ConvKernel::DispatchPolicy::ClusterShape{}), + cute::size<1>(typename ConvKernel::DispatchPolicy::ClusterShape{}), + cute::size<2>(typename ConvKernel::DispatchPolicy::ClusterShape{})>; + + // Instruction shape is easy too, since we get that directly from our TiledMma's atom shape + using InstructionShape = cutlass::gemm::GemmShape< + cute::size<0>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), + cute::size<1>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{}), + cute::size<2>(typename CollectiveMainloop::TiledMma::AtomShape_MNK{})>; + + // Legacy: provide a correct warp count, but no reliable warp shape + static int const kThreadCount = ConvKernel::MaxThreadsPerBlock; + + // Warp shape is not a primary API type in 3.x + // But we can best approximate it by inspecting the TiledMma + // For this, we make the assumption that we always have 4 warps along M, and rest along N, none along K + // We also always round up the warp count to 4 if the tiled mma is smaller than 128 threads + static constexpr int WarpsInMma = cute::max(4, CUTE_STATIC_V(cute::size(typename ConvKernel::TiledMma{})) / 32); + static constexpr int WarpsInMmaM = 4; + static constexpr int WarpsInMmaN = cute::ceil_div(WarpsInMma, WarpsInMmaM); + using WarpCount = cutlass::gemm::GemmShape; + using WarpShape = cutlass::gemm::GemmShape< + CUTE_STATIC_V(cute::tile_size<0>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaM, + CUTE_STATIC_V(cute::tile_size<1>(typename CollectiveMainloop::TiledMma{})) / WarpsInMmaN, + CUTE_STATIC_V(cute::tile_size<2>(typename CollectiveMainloop::TiledMma{}))>; + + static int constexpr kStages = 
CollectiveMainloop::DispatchPolicy::Stages; + + // Inspect TiledCopy for A and B to compute the alignment size + static int constexpr kAlignmentA = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveMainloop::GmemTiledCopyA, ElementA>(); + static int constexpr kAlignmentB = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveMainloop::GmemTiledCopyB, ElementB>(); + static int constexpr kAlignmentC = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveEpilogue::GmemTiledCopyC, ElementC>(); + static int constexpr kAlignmentD = cutlass::detail::get_alignment_count_from_gmem_tiled_copy< + typename CollectiveEpilogue::GmemTiledCopyD, ElementD>(); + + using EpilogueOutputOp = typename CollectiveEpilogue::ThreadEpilogueOp; + + /// Argument structure: User API + using Arguments = typename ConvKernel::Arguments; + /// Argument structure: Kernel API + using Params = typename ConvKernel::Params; + +private: + + /// Kernel API parameters object + Params params_; + +public: + + /// Access the Params structure + Params const& params() const { + return params_; + } + + /// Determines whether the conv can execute the given problem. + static Status + can_implement(Arguments const& args) { + if (ConvKernel::can_implement(args)) { + return Status::kSuccess; + } + else { + return Status::kInvalid; + } + } + + /// Gets the workspace size + static size_t + get_workspace_size(Arguments const& args) { + size_t workspace_bytes = 0; + CUTLASS_TRACE_HOST(" workspace_bytes: " << workspace_bytes); + + workspace_bytes += ConvKernel::get_workspace_size(args); + return workspace_bytes; + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Arguments const& args, void* workspace = nullptr) { + auto tmp_params = ConvKernel::to_underlying_arguments(args, workspace); + return ConvKernel::get_grid_shape(tmp_params); + } + + /// Computes the grid shape + static dim3 + get_grid_shape(Params const& params) { + return ConvKernel::get_grid_shape(params); + } + + /// Computes the maximum number of active blocks per multiprocessor + static int maximum_active_blocks(int /* smem_capacity */ = -1) { + CUTLASS_TRACE_HOST("ConvUniversal::maximum_active_blocks()"); + int max_active_blocks = -1; + int smem_size = ConvKernel::SharedStorageSize; + + // first, account for dynamic smem capacity if needed + cudaError_t result; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaFuncSetAttribute() returned error: " + << cudaGetErrorString(result)); + return -1; + } + } + + // query occupancy after setting smem size + result = cudaOccupancyMaxActiveBlocksPerMultiprocessor( + &max_active_blocks, + device_kernel, + ConvKernel::MaxThreadsPerBlock, + smem_size); + + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST( + " cudaOccupancyMaxActiveBlocksPerMultiprocessor() returned error: " + << cudaGetErrorString(result)); + return -1; + } + + CUTLASS_TRACE_HOST(" max_active_blocks: " << max_active_blocks); + return max_active_blocks; + } + + /// Initializes conv state from arguments. 
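A rough host-side sketch of how this adapter is meant to be driven, shown as a free function template; `MyConvKernel` stands for a `cutlass::conv::kernel::ConvUniversal` instantiation composed elsewhere (for example with the collective builders), and the bare cudaMalloc/cudaFree workspace handling is illustrative only, not part of this header:

template <class MyConvKernel>
void launch_conv(typename MyConvKernel::Arguments const& args, cudaStream_t stream = nullptr) {
  using ConvOp = cutlass::conv::device::ConvUniversalAdapter<MyConvKernel>;
  ConvOp conv_op;

  // Reject problems the kernel cannot handle before allocating anything.
  if (ConvOp::can_implement(args) != cutlass::Status::kSuccess) {
    return;
  }

  // Allocate whatever scratch space the kernel requests for this problem.
  void* workspace = nullptr;
  cudaMalloc(&workspace, ConvOp::get_workspace_size(args));

  // initialize() lowers Arguments to Params and prepares the workspace; run() launches.
  if (conv_op.initialize(args, workspace, stream) == cutlass::Status::kSuccess) {
    conv_op.run(stream);
  }
  cudaFree(workspace);
}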
+ Status + initialize( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + + CUTLASS_TRACE_HOST("ConvUniversal::initialize() - workspace " + << workspace << ", stream: " << (stream ? "non-null" : "null")); + + // Initialize the workspace + Status status = ConvKernel::initialize_workspace(args, workspace, stream, cuda_adapter); + if (status != Status::kSuccess) { + return status; + } + + // Initialize the Params structure + params_ = ConvKernel::to_underlying_arguments(args, workspace); + + // Don't set the function attributes - require the CudaHostAdapter to set it. + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + return Status::kSuccess; + } + else { + // account for dynamic smem capacity if needed + int smem_size = ConvKernel::SharedStorageSize; + if (smem_size >= (48 << 10)) { + CUTLASS_TRACE_HOST(" Setting smem size to " << smem_size); + cudaError_t result = cudaFuncSetAttribute( + device_kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + if (cudaSuccess != result) { + result = cudaGetLastError(); // to clear the error bit + CUTLASS_TRACE_HOST(" cudaFuncSetAttribute() returned error: " << cudaGetErrorString(result)); + return Status::kErrorInternal; + } + } + } + return Status::kSuccess; + } + + /// Update API is preserved in 3.0, but does not guarantee a lightweight update of params. + Status + update(Arguments const& args, void* workspace = nullptr) { + CUTLASS_TRACE_HOST("ConvUniversal()::update() - workspace: " << workspace); + + size_t workspace_bytes = get_workspace_size(args); + if (workspace_bytes > 0 && nullptr == workspace) { + return Status::kErrorWorkspaceNull; + } + + params_ = ConvKernel::to_underlying_arguments(args, workspace); + return Status::kSuccess; + } + + /// Primary run() entry point API that is static allowing users to create and manage their own params. 
+ /// Supplied params struct must be constructed by calling ConvKernel::to_underlying_arguments() + static Status + run(Params& params, cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) { + CUTLASS_TRACE_HOST("ConvUniversal::run()"); + dim3 const block = ConvKernel::get_block_shape(); + dim3 const grid = get_grid_shape(params); + + // configure smem size and carveout + int smem_size = ConvKernel::SharedStorageSize; + + Status launch_result; + // Use extended launch API only for mainloops that use it + if constexpr (ConvKernel::ArchTag::kMinComputeCapability >= 90) { + [[maybe_unused]] constexpr bool is_static_1x1x1 = + cute::is_static_v<typename ConvKernel::DispatchPolicy::ClusterShape> and + cute::size(typename ConvKernel::DispatchPolicy::ClusterShape{}) == 1; + dim3 cluster(cute::size<0>(typename ConvKernel::DispatchPolicy::ClusterShape{}), + cute::size<1>(typename ConvKernel::DispatchPolicy::ClusterShape{}), + cute::size<2>(typename ConvKernel::DispatchPolicy::ClusterShape{})); + void* kernel_params[] = {&params}; + if constexpr (kEnableCudaHostAdapter) { + // + // Use the cuda host adapter + // + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + + launch_result = cuda_adapter->launch(grid, + cluster, + block, + smem_size, + stream, + kernel_params, + kernel_index); + } + else { + return Status::kErrorInternal; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); + void const* kernel = (void const*) device_kernel<ConvKernel>; + if constexpr (ConvKernel::ArchTag::kMinComputeCapability == 90) { + if constexpr (is_static_1x1x1) { + device_kernel<ConvKernel><<<grid, block, smem_size, stream>>>(params); + launch_result = Status::kSuccess; + } + else { + launch_result = ClusterLauncher::launch( + grid, cluster, block, smem_size, stream, kernel, kernel_params); + } + } + } + } + else { + launch_result = Status::kSuccess; + + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + void* kernel_params[] = {&params}; + + launch_result = cuda_adapter->launch( + grid, block, smem_size, stream, kernel_params, 0 + ); + + } + else { + return Status::kErrorInternal; + } + } + else { + CUTLASS_ASSERT(cuda_adapter == nullptr); + device_kernel<ConvKernel><<<grid, block, smem_size, stream>>>(params); + } + } + + cudaError_t result = cudaGetLastError(); + if (cudaSuccess == result && Status::kSuccess == launch_result) { + return Status::kSuccess; + } + else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } + } + + // + // Non-static launch overloads that first create and set the internal params struct of this kernel handle. + // + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + run( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr, + int32_t kernel_index = 0 + ) { + Status status = initialize(args, workspace, stream, cuda_adapter); + if (Status::kSuccess == status) { + status = run(params_, stream, cuda_adapter, kernel_index); + } + return status; + } + + /// Launches the kernel after first constructing Params internal state from supplied arguments. + Status + operator()( + Arguments const& args, + void* workspace = nullptr, + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { + return run(args, workspace, stream, cuda_adapter); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct.
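A companion sketch for the power-user path described above, where the caller lowers Arguments to Params once and reuses them across launches; `MyConvKernel`, `args`, `workspace`, and `stream` are placeholders, and workspace initialization is assumed to have been done separately:

// Lower the host-facing Arguments once, then launch through the static entry point
// as many times as needed, without holding an adapter instance.
auto params = MyConvKernel::to_underlying_arguments(args, workspace);

using Adapter = cutlass::conv::device::ConvUniversalAdapter<MyConvKernel>;
Adapter::run(params, stream);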
+ Status + run(cudaStream_t stream = nullptr) { + return run(params_, stream); + } + + /// Overload that allows a user to re-launch the same kernel without updating internal params struct. + Status + operator()(cudaStream_t stream = nullptr) { + return run(params_, stream); + } +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::device + +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/device/direct_convolution.h b/include/cutlass/conv/device/direct_convolution.h new file mode 100644 index 0000000000..43ab94b5fc --- /dev/null +++ b/include/cutlass/conv/device/direct_convolution.h @@ -0,0 +1,270 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +/* \file + \brief Template for device-level Depthwise Convolution +*/ + +#pragma once + +#include + +#include "cutlass/cutlass.h" +#include "cutlass/device_kernel.h" +#include "cutlass/conv/convolution.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace device { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +class DirectConvolution { +public: + + using UnderlyingKernel = DirectConvolutionKernel_; + + using ElementA = typename UnderlyingKernel::ElementA; + using LayoutA = typename UnderlyingKernel::LayoutA; + using ElementB = typename UnderlyingKernel::ElementB; + using LayoutB = typename UnderlyingKernel::LayoutB; + using ElementC = typename UnderlyingKernel::ElementC; + using LayoutC = typename UnderlyingKernel::LayoutC; + using ElementAccumulator = typename UnderlyingKernel::ElementAccumulator; + using ElementCompute = typename UnderlyingKernel::ElementCompute; + using OperatorClass = typename UnderlyingKernel::OperatorClass; + using ArchTag = typename UnderlyingKernel::ArchTag; + using ThreadblockShape = typename UnderlyingKernel::ThreadblockShape; + using WarpShape = typename UnderlyingKernel::WarpShape; + using InstructionShape = typename UnderlyingKernel::InstructionShape; + using ThreadblockSwizzle = typename UnderlyingKernel::ThreadblockSwizzle; + using EpilogueOutputOp = typename UnderlyingKernel::EpilogueOutputOp; + static int const kStages = UnderlyingKernel::kStages; + static int const kConvDim = UnderlyingKernel::kConvDim; + using WarpMmaOperator = typename UnderlyingKernel::WarpMmaOperator; + using ArchMmaOperator = typename UnderlyingKernel::ArchMmaOperator; + using MathOperator = typename UnderlyingKernel::MathOperator; + + static cutlass::conv::Operator const kConvolutionalOperator = UnderlyingKernel::kConvolutionalOperator; + static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = UnderlyingKernel::kIteratorAlgorithm; + static cutlass::conv::StrideSupport const kStrideSupport = UnderlyingKernel::kStrideSupport; + static cutlass::conv::GroupMode const kGroupMode = UnderlyingKernel::kGroupMode; + + static int const kWarpCount = + (ThreadblockShape::kM / WarpShape::kM) * + (ThreadblockShape::kN / WarpShape::kN) * + (ThreadblockShape::kK / WarpShape::kK); + + /// Argument structure + using Arguments = typename UnderlyingKernel::Arguments; + + using ReorderKernel = typename UnderlyingKernel::ReorderKernel; + + private: + + /// Kernel parameters object + typename UnderlyingKernel::Params params_; + +public: + + /// Constructs Implicit GEMM + DirectConvolution() { } + + /// Determines whether the Implicit GEMM can execute the given problem. 
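A minimal sketch of how this handle is typically used for a depthwise problem; `DepthwiseKernel` is a placeholder for a direct-convolution kernel type assembled elsewhere (for example via the depthwise default-kernel builders), and error handling is abbreviated:

template <class DepthwiseKernel>
void run_depthwise(typename DepthwiseKernel::Arguments const& args, cudaStream_t stream = nullptr) {
  using DeviceConv = cutlass::conv::device::DirectConvolution<DepthwiseKernel>;
  DeviceConv conv_op;

  // GroupMode::kDepthwise expects problem_size.C == problem_size.K == problem_size.groups.
  if (DeviceConv::can_implement(args) != cutlass::Status::kSuccess) {
    return;
  }

  // initialize() builds Params; run() launches the filter-reorder kernel first
  // (when a reordered-B buffer is supplied) and then the main depthwise kernel.
  if (conv_op.initialize(args, /*workspace=*/nullptr, stream) == cutlass::Status::kSuccess) {
    conv_op(stream);
  }
}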
+ static Status can_implement(Arguments const &args) { + + // dispatch to iterators + Status status = UnderlyingKernel::Mma::IteratorA::can_implement(args.problem_size); + if (Status::kSuccess != status) { + return status; + } + + status = UnderlyingKernel::Mma::IteratorB::can_implement(args.problem_size); + if (Status::kSuccess != status) { + return status; + } + + if (kGroupMode != conv::GroupMode::kDepthwise) { + return Status::kErrorInvalidProblem; + } + + // C and K should be multiple of groups + if (args.problem_size.K != args.problem_size.groups && + args.problem_size.C != args.problem_size.groups) { + return Status::kErrorInvalidProblem; + } + + + static int const kAlignmentC = UnderlyingKernel::Epilogue::OutputTileIterator::kElementsPerAccess; + if (kConvolutionalOperator == conv::Operator::kFprop) { + if (args.problem_size.K % kAlignmentC) + return Status::kErrorMisalignedOperand; + } else if (kConvolutionalOperator == conv::Operator::kDgrad) { + if (args.problem_size.C % kAlignmentC) + return Status::kErrorMisalignedOperand; + } else if (kConvolutionalOperator == conv::Operator::kWgrad) { + if (args.problem_size.C % kAlignmentC) + return Status::kErrorMisalignedOperand; + } + + // Determine grid shape + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape( + threadblock_swizzle.get_tiled_shape( + kConvolutionalOperator, + args.problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.problem_size.split_k_slices)); + + if (!(grid.y <= std::numeric_limits::max() && + grid.z <= std::numeric_limits::max())) { + + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } + + /// Gets the workspace size + static size_t get_workspace_size(Arguments const &args) { + return 0; + } + + /// Initializes GEMM state from arguments. + Status initialize( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + // initialize the params structure from the arguments + params_ = typename UnderlyingKernel::Params( + args, + static_cast(workspace) + ); + + int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } + } + + return Status::kSuccess; + } + + /// Initializes GEMM state from arguments. + Status update(Arguments const &args, void *workspace = nullptr) { + + // update the params structure from the arguments + params_.ptr_A = args.ref_A.data(); + params_.ptr_B = args.ref_B.data(); + params_.ptr_C = args.ref_C.data(); + params_.ptr_D = args.ref_D.data(); + params_.output_op = args.output_op; + params_.ptr_reordered_B = args.ref_reordered_B.data(); + params_.semaphore = static_cast(workspace); + + return Status::kSuccess; + } + + /// Runs the kernel using initialized state. + Status run(cudaStream_t stream = nullptr) { + + // Launch reorder kernel + if (params_.ptr_reordered_B != nullptr) { + dim3 grid = ReorderKernel::get_grid_shape(params_); + dim3 block = ReorderKernel::get_block_shape(); + + cutlass::arch::synclog_setup(); + cutlass::Kernel<<>>(params_); + } + + // Launch main kernel + ThreadblockSwizzle threadblock_swizzle; + + dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); + dim3 block(32 * kWarpCount, 1, 1); + + // Dynamic SMEM size based on input params. 
+ int smem_size = int(params_.get_smem_size()); + + // Make sure we can use that much shared memory. + cudaError_t status = + cudaFuncSetAttribute(cutlass::Kernel, cudaFuncAttributeMaxDynamicSharedMemorySize, smem_size); + if (status != cudaSuccess) + return Status::kErrorInternal; + + cutlass::arch::synclog_setup(); + cutlass::Kernel<<>>(params_); + + cudaError_t result = cudaGetLastError(); + + return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + } + + /// Runs the kernel using initialized state. + Status operator()(cudaStream_t stream = nullptr) { + return run(stream); + } + + /// Runs the kernel using initialized state. + Status operator()( + Arguments const &args, + void *workspace = nullptr, + cudaStream_t stream = nullptr) { + + Status status = initialize(args, workspace, stream); + + if (status == Status::kSuccess) { + status = run(stream); + } + + return status; + } + + int get_smem_size() { return int(params_.get_smem_size()); } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} +} +} + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/device/implicit_gemm_convolution.h b/include/cutlass/conv/device/implicit_gemm_convolution.h index 8e87ec566a..a1cb06e98f 100644 --- a/include/cutlass/conv/device/implicit_gemm_convolution.h +++ b/include/cutlass/conv/device/implicit_gemm_convolution.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -39,6 +39,7 @@ #include "cutlass/cutlass.h" #include "cutlass/device_kernel.h" #include "cutlass/conv/convolution.h" +#include "cutlass/cuda_host_adapter.hpp" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -52,32 +53,35 @@ template class ImplicitGemmConvolution { public: - using ImplicitGemmKernel = ImplicitGemmKernel_; - - using ElementA = typename ImplicitGemmKernel::ElementA; - using LayoutA = typename ImplicitGemmKernel::LayoutA; - using ElementB = typename ImplicitGemmKernel::ElementB; - using LayoutB = typename ImplicitGemmKernel::LayoutB; - using ElementC = typename ImplicitGemmKernel::ElementC; - using LayoutC = typename ImplicitGemmKernel::LayoutC; - using ElementAccumulator = typename ImplicitGemmKernel::ElementAccumulator; - using ElementCompute = typename ImplicitGemmKernel::ElementCompute; - using OperatorClass = typename ImplicitGemmKernel::OperatorClass; - using ArchTag = typename ImplicitGemmKernel::ArchTag; - using ThreadblockShape = typename ImplicitGemmKernel::ThreadblockShape; - using WarpShape = typename ImplicitGemmKernel::WarpShape; - using InstructionShape = typename ImplicitGemmKernel::InstructionShape; - using ThreadblockSwizzle = typename ImplicitGemmKernel::ThreadblockSwizzle; - using EpilogueOutputOp = typename ImplicitGemmKernel::EpilogueOutputOp; - static int const kStages = ImplicitGemmKernel::kStages; - static int const kConvDim = ImplicitGemmKernel::kConvDim; - using WarpMmaOperator = typename ImplicitGemmKernel::WarpMmaOperator; - using ArchMmaOperator = typename ImplicitGemmKernel::ArchMmaOperator; - using MathOperator = typename 
ImplicitGemmKernel::MathOperator; - - static cutlass::conv::Operator const kConvolutionalOperator = ImplicitGemmKernel::kConvolutionalOperator; - static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = ImplicitGemmKernel::kIteratorAlgorithm; - static cutlass::conv::StrideSupport const kStrideSupport = ImplicitGemmKernel::kStrideSupport; + using UnderlyingKernel = GetUnderlyingKernel_t; + + using ElementA = typename UnderlyingKernel::ElementA; + using LayoutA = typename UnderlyingKernel::LayoutA; + using ElementB = typename UnderlyingKernel::ElementB; + using LayoutB = typename UnderlyingKernel::LayoutB; + using ElementC = typename UnderlyingKernel::ElementC; + using LayoutC = typename UnderlyingKernel::LayoutC; + using ElementAccumulator = typename UnderlyingKernel::ElementAccumulator; + using ElementCompute = typename UnderlyingKernel::ElementCompute; + using OperatorClass = typename UnderlyingKernel::OperatorClass; + using ArchTag = typename UnderlyingKernel::ArchTag; + using ThreadblockShape = typename UnderlyingKernel::ThreadblockShape; + using WarpShape = typename UnderlyingKernel::WarpShape; + using InstructionShape = typename UnderlyingKernel::InstructionShape; + using ThreadblockSwizzle = typename UnderlyingKernel::ThreadblockSwizzle; + using EpilogueOutputOp = typename UnderlyingKernel::EpilogueOutputOp; + static int const kStages = UnderlyingKernel::kStages; + static int const kConvDim = UnderlyingKernel::kConvDim; + using WarpMmaOperator = typename UnderlyingKernel::WarpMmaOperator; + using ArchMmaOperator = typename UnderlyingKernel::ArchMmaOperator; + using MathOperator = typename UnderlyingKernel::MathOperator; + + static cutlass::conv::Operator const kConvolutionalOperator = UnderlyingKernel::kConvolutionalOperator; + static cutlass::conv::IteratorAlgorithm const kIteratorAlgorithm = UnderlyingKernel::kIteratorAlgorithm; + static cutlass::conv::StrideSupport const kStrideSupport = UnderlyingKernel::kStrideSupport; + static cutlass::conv::GroupMode const kGroupMode = UnderlyingKernel::kGroupMode; + + static bool const kEnableCudaHostAdapter = CUTLASS_ENABLE_CUDA_HOST_ADAPTER; static int const kWarpCount = (ThreadblockShape::kM / WarpShape::kM) * @@ -85,12 +89,12 @@ class ImplicitGemmConvolution { (ThreadblockShape::kK / WarpShape::kK); /// Argument structure - using Arguments = typename ImplicitGemmKernel::Arguments; + using Arguments = typename UnderlyingKernel::Arguments; private: /// Kernel parameters object - typename ImplicitGemmKernel::Params params_; + typename UnderlyingKernel::Params params_; public: @@ -99,23 +103,56 @@ class ImplicitGemmConvolution { /// Determines whether the Implicit GEMM can execute the given problem. 
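To make the group-mode checks added in the hunk below concrete, here is the same arithmetic with hypothetical numbers (all values are assumptions chosen purely for illustration):

// Hypothetical configuration: ThreadblockShape::kN == 64, K == 256, C == 128, groups == 4.
constexpr int kThreadblockN = 64;
constexpr int K = 256, C = 128, groups = 4;

// C and K must both be multiples of groups.
static_assert(K % groups == 0 && C % groups == 0, "C and K must be multiples of groups");

// 64 filters per group: one CTA tile stays inside a single group, so kSingleGroup applies.
constexpr int k_per_group = K / groups;
static_assert(k_per_group % kThreadblockN == 0,
              "kSingleGroup needs k_per_group to be a multiple of ThreadblockShape::kN");

// If k_per_group were 32 instead (smaller than the CTA tile), only kMultipleGroup could apply,
// which requires ThreadblockShape::kN % k_per_group == 0; the optimized iterator algorithm
// additionally supports only kSingleGroup.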
static Status can_implement(Arguments const &args) { - // dispatch to iterators - Status status = ImplicitGemmKernel::Mma::IteratorA::can_implement(args.problem_size); + Status status = UnderlyingKernel::Mma::IteratorA::can_implement(args.problem_size); if (Status::kSuccess != status) { return status; } - status = ImplicitGemmKernel::Mma::IteratorB::can_implement(args.problem_size); + status = UnderlyingKernel::Mma::IteratorB::can_implement(args.problem_size); if (Status::kSuccess != status) { return status; } - static int const kAlignmentC = ImplicitGemmKernel::Epilogue::OutputTileIterator::kElementsPerAccess; + // check group conv constraint + if (args.problem_size.groups != 1) { + if (kGroupMode == conv::GroupMode::kNone) { + return Status::kErrorInvalidProblem; + } + + // C and K should be multiple of groups + if (args.problem_size.K % args.problem_size.groups || + args.problem_size.C % args.problem_size.groups) { + return Status::kErrorInvalidProblem; + } + + // split-k is not supported + if (args.problem_size.split_k_slices != 1) { + return Status::kErrorInvalidProblem; + } + + int k_per_group = args.problem_size.K / args.problem_size.groups; + // k_per_group should be multiple of ThreadblockShape N, one CTA calculate one group + if (kGroupMode == conv::GroupMode::kSingleGroup && k_per_group % ThreadblockShape::kN) { + return Status::kErrorInvalidProblem; + } + // ThreadblockShape::kN should be divisible by k_per_group, one CTA calculate multiple groups + if (kGroupMode == conv::GroupMode::kMultipleGroup && ThreadblockShape::kN % k_per_group) { + return Status::kErrorInvalidProblem; + } + + // current optimized iterator algo only supports SingleGroup mode + if (kIteratorAlgorithm == IteratorAlgorithm::kOptimized && + kGroupMode != conv::GroupMode::kSingleGroup) { + return Status::kErrorInvalidProblem; + } + } + + static int const kAlignmentC = UnderlyingKernel::Epilogue::OutputTileIterator::kElementsPerAccess; if (kConvolutionalOperator == conv::Operator::kFprop) { if (args.problem_size.K % kAlignmentC) return Status::kErrorMisalignedOperand; - } else if (kConvolutionalOperator == conv::Operator::kDgrad) { + } else if (kConvolutionalOperator == conv::Operator::kDgrad || kConvolutionalOperator == conv::Operator::kDeconv) { if (args.problem_size.C % kAlignmentC) return Status::kErrorMisalignedOperand; } else if (kConvolutionalOperator == conv::Operator::kWgrad) { @@ -123,25 +160,15 @@ class ImplicitGemmConvolution { return Status::kErrorMisalignedOperand; } - // check for unsupported problem sizes for strided dgrad implementation - if (kConvolutionalOperator == conv::Operator::kDgrad && + // check for unsupported problem sizes for strided dgrad / deconv implementation + if ((kConvolutionalOperator == conv::Operator::kDgrad || kConvolutionalOperator == conv::Operator::kDeconv) && kStrideSupport == conv::StrideSupport::kStrided) { - - // Unity stride (1x1) is supported by strided dgrad but disabled for performance - // reasons. For unity stride, use strided dgrad optimized unity stride specialization. - // Note that unit tests strided dgrad for unity stride to make sure that strided - // dgrad implemetnation is functionaly sound. 
- // Strided dgrad implementation also support mixed strides, i.e., (1x2) and (2x1) - if(args.problem_size.stride_h == 1 && args.problem_size.stride_w == 1) { + // split-k (serial or parallel) is not supported for strided dgrad / deconv + if(args.problem_size.split_k_slices > 1 && (args.problem_size.stride().at(args.problem_size.stride().max_dim_index()) > 1)) { return Status::kErrorNotSupported; } - // split-k (serial or parallel) is not supported for strided dgrad - if(args.problem_size.split_k_slices > 1) { - return Status::kErrorNotSupported; - } - - // dilation > {1x1} is not supported for strided dgrad + // dilation > {1x1} is not supported for strided dgrad / deconv if(args.problem_size.dilation_h > 1 || args.problem_size.dilation_w > 1) { return Status::kErrorNotSupported; } @@ -204,7 +231,8 @@ class ImplicitGemmConvolution { Status initialize( Arguments const &args, void *workspace = nullptr, - cudaStream_t stream = nullptr) { + cudaStream_t stream = nullptr, + CudaHostAdapter *cuda_adapter = nullptr) { if (args.problem_size.split_k_slices > 1) { @@ -220,20 +248,26 @@ class ImplicitGemmConvolution { } // initialize the params structure from the arguments - params_ = typename ImplicitGemmKernel::Params( + params_ = typename UnderlyingKernel::Params( args, static_cast(workspace) ); - - int smem_size = int(sizeof(typename ImplicitGemmKernel::SharedStorage)); - if (smem_size >= (48 << 10)) { - cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel, - cudaFuncAttributeMaxDynamicSharedMemorySize, - smem_size); - - if (result != cudaSuccess) { - return Status::kErrorInternal; + if constexpr (kEnableCudaHostAdapter) { + CUTLASS_ASSERT(cuda_adapter); + return Status::kSuccess; + } + else { + int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage)); + + if (smem_size >= (48 << 10)) { + cudaError_t result = cudaFuncSetAttribute(cutlass::Kernel, + cudaFuncAttributeMaxDynamicSharedMemorySize, + smem_size); + + if (result != cudaSuccess) { + return Status::kErrorInternal; + } } } @@ -255,7 +289,7 @@ class ImplicitGemmConvolution { } /// Runs the kernel using initialized state. - Status run(cudaStream_t stream = nullptr) { + Status run(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) { ThreadblockSwizzle threadblock_swizzle; @@ -263,30 +297,55 @@ class ImplicitGemmConvolution { dim3 grid = threadblock_swizzle.get_grid_shape(params_.grid_tiled_shape); dim3 block(32 * kWarpCount, 1, 1); - int smem_size = int(sizeof(typename ImplicitGemmKernel::SharedStorage)); - - cutlass::Kernel<<>>(params_); + int smem_size = int(sizeof(typename UnderlyingKernel::SharedStorage)); + cutlass::Status launch_result = cutlass::Status::kSuccess ; + + if constexpr (kEnableCudaHostAdapter) { + // + // Use the cuda host adapter + // + CUTLASS_ASSERT(cuda_adapter); + if (cuda_adapter) { + + void* kernel_params[] = {¶ms_}; + launch_result = cuda_adapter->launch( + grid, dim3(1,1,1), block, smem_size, stream, kernel_params, kernel_index + ); + } + else { + launch_result = Status::kErrorInternal; + } + } + else { + cutlass::arch::synclog_setup(); + cutlass::Kernel<<>>(params_); + } cudaError_t result = cudaGetLastError(); - - return result == cudaSuccess ? Status::kSuccess : Status::kErrorInternal; + if (cudaSuccess == result && Status::kSuccess == launch_result) { + return Status::kSuccess; + } + else { + CUTLASS_TRACE_HOST(" Kernel launch failed. Reason: " << result); + return Status::kErrorInternal; + } } /// Runs the kernel using initialized state. 
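For contrast with the 3.x adapter earlier in this patch, a sketch of driving the updated 2.x device-level handle; `Conv2dFpropKernel` is a placeholder for a kernel built in the usual way (for example through `DefaultConv2dFprop`), and the Arguments object is assumed to have been filled from a `Conv2dProblemSize` plus tensor refs as before:

template <class Conv2dFpropKernel>
cutlass::Status run_fprop(
    typename cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>::Arguments const& arguments,
    cudaStream_t stream = nullptr) {
  using ImplicitGemm = cutlass::conv::device::ImplicitGemmConvolution<Conv2dFpropKernel>;
  ImplicitGemm implicit_gemm_op;

  // Workspace is only needed for split-K style runs; a zero-byte allocation is harmless.
  void* workspace = nullptr;
  cudaMalloc(&workspace, ImplicitGemm::get_workspace_size(arguments));

  // The new CudaHostAdapter / kernel_index parameters are optional and default to the
  // classic driver-launched path shown here.
  cutlass::Status status = implicit_gemm_op(arguments, workspace, stream);
  cudaFree(workspace);
  return status;
}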
- Status operator()(cudaStream_t stream = nullptr) { - return run(stream); + Status operator()(cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) { + return run(stream, cuda_adapter, kernel_index); } /// Runs the kernel using initialized state. Status operator()( Arguments const &args, void *workspace = nullptr, - cudaStream_t stream = nullptr) { + cudaStream_t stream = nullptr, CudaHostAdapter *cuda_adapter = nullptr, int32_t kernel_index = 0) { - Status status = initialize(args, workspace, stream); + Status status = initialize(args, workspace, stream, cuda_adapter); if (status == Status::kSuccess) { - status = run(stream); + status = run(stream, cuda_adapter, kernel_index); } return status; diff --git a/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h b/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h index b0c590add3..265156cc5b 100644 --- a/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h +++ b/include/cutlass/conv/device/implicit_gemm_convolution_fusion.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -231,6 +231,7 @@ class ImplicitGemmConvolutionFusion { int smem_size = int(sizeof(typename ImplicitGemmFusionKernel::SharedStorage)); + cutlass::arch::synclog_setup(); cutlass::Kernel<<>>(params_); cudaError_t result = cudaGetLastError(); diff --git a/include/cutlass/conv/dispatch_policy.hpp b/include/cutlass/conv/dispatch_policy.hpp new file mode 100644 index 0000000000..b8b5eb2bff --- /dev/null +++ b/include/cutlass/conv/dispatch_policy.hpp @@ -0,0 +1,90 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/conv/convolution.h" +#include "cutlass/epilogue/thread/activation.h" +#include "cutlass/arch/arch.h" + +#include "cute/layout.hpp" +#include "cute/numeric/integral_constant.hpp" + +#include "cutlass/gemm/dispatch_policy.hpp" + +////////////////////////////////////////////////////////////////////////////// + +////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv { + +////////////////////////////////////////////////////////////////////////////// + +// +// Policies for categorical dispatch of mainloop against kernel grid schedules +// +struct KernelImplicitTmaWarpSpecializedSm90 : cutlass::gemm::KernelTmaWarpSpecialized { }; +struct KernelImplicitTmaWarpSpecializedSm90Cooperative { }; +struct KernelImplicitTmaWarpSpecializedSm90Pingpong { }; + +// +// Collective Mainloop Policies +// + +// n-buffer in smem (Hopper TMA), pipelined with Hopper GMMA and TMA, static schedule between TMA and GMMA +// for fprop +template< + conv::Operator ConvOp_, + int Stages_, + int NumSpatialDimensions_, + class ClusterShape_ = cute::Shape<cute::C<1>,cute::C<1>,cute::C<1>>, + class KernelSchedule = KernelImplicitTmaWarpSpecializedSm90, + int PipelineAsyncMmaStages_ = 1 +> +struct MainloopSm90TmaGmmaWarpSpecializedImplicitGemm { + static constexpr int Stages = Stages_; + static constexpr int NumSpatialDimensions = NumSpatialDimensions_; + static constexpr Operator ConvOp = ConvOp_; + static constexpr int PipelineAsyncMmaStages = PipelineAsyncMmaStages_; + using ClusterShape = ClusterShape_; + using ArchTag = arch::Sm90; + using Schedule = KernelSchedule; + + static_assert(NumSpatialDimensions >= 1); + static_assert(! (cute::is_same_v<KernelSchedule, KernelImplicitTmaWarpSpecializedSm90Cooperative> || + cute::is_same_v<KernelSchedule, KernelImplicitTmaWarpSpecializedSm90Pingpong>), + "Persistent schedules not supported for conv yet."); +}; + +////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv + +////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/conv_universal.hpp b/include/cutlass/conv/kernel/conv_universal.hpp new file mode 100644 index 0000000000..23ccea2f8f --- /dev/null +++ b/include/cutlass/conv/kernel/conv_universal.hpp @@ -0,0 +1,65 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2.
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/conv/convnd_problem_shape.hpp" +#include "cutlass/detail/dependent_false.hpp" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::kernel { + +//////////////////////////////////////////////////////////////////////////////// + +/* + * Stateless universal device CONV kernel type that treats CONV as + * a composition of a collective mainloop and a collective epilogue. +**/ +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileSchedulerTag_ = void, + class Enable = void +> +class ConvUniversal { + static_assert(cutlass::detail::dependent_false, + "Could not find a valid specialization at the kernel layer to dispatch against."); +}; + +//////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::kernel + +//////////////////////////////////////////////////////////////////////////////// + +#include "cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp" +//////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/default_conv2d.h b/include/cutlass/conv/kernel/default_conv2d.h index 00cff2378c..79bedb2c84 100644 --- a/include/cutlass/conv/kernel/default_conv2d.h +++ b/include/cutlass/conv/kernel/default_conv2d.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -106,6 +106,56 @@ struct DefaultConvEpilogue< }; ///////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename ArchTag, + typename Shape, + typename WarpMmaSimt, + typename ElementOutput, + typename ElementTensor, + typename ElementVector, + typename OutputOp, + int ElementsPerAccess, + typename PermuteDLayout = layout::NoPermute, + conv::StrideSupport StrideSupport = conv::StrideSupport::kUnity, + int Rank = 4 +> +struct DefaultConvEpilogueWithBroadcastSimt { + using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithBroadcastSimt< + Shape, + WarpMmaSimt, + ElementOutput, + ElementTensor, + ElementVector, + OutputOp, + ElementsPerAccess, + false, + PermuteDLayout, + StrideSupport, + Rank + >::Epilogue; +}; + +template < + typename ArchTag, + typename Shape, + typename WarpMmaSimt, + typename ElementOutput, + typename ElementTensor, + typename ElementVector, + typename OutputOp, + int ElementsPerAccess +> +struct DefaultConvEpilogueWithBroadcastSimtStridedDgrad { + using Epilogue = typename epilogue::threadblock::DefaultEpilogueWithBroadcastSimtStridedDgrad< + Shape, + WarpMmaSimt, + ElementOutput, + ElementTensor, + ElementVector, + OutputOp, + ElementsPerAccess + >::Epilogue; +}; template < typename ArchTag, diff --git a/include/cutlass/conv/kernel/default_conv2d_dgrad.h b/include/cutlass/conv/kernel/default_conv2d_dgrad.h index 547068713b..c5a8b1315e 100644 --- a/include/cutlass/conv/kernel/default_conv2d_dgrad.h +++ b/include/cutlass/conv/kernel/default_conv2d_dgrad.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop.h b/include/cutlass/conv/kernel/default_conv2d_fprop.h index 859f23969a..9fbd97e585 100644 --- a/include/cutlass/conv/kernel/default_conv2d_fprop.h +++ b/include/cutlass/conv/kernel/default_conv2d_fprop.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -76,7 +76,7 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, /// Access granularity of A matrix in units of elements int AlignmentA = 128 / cutlass::sizeof_bits::value, /// Access granularity of B matrix in units of elements @@ -327,7 +327,6 @@ struct DefaultConv2dFprop < >; }; - ///////////////////////////////////////////////////////////////////////////////////////////////// /// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and two stage @@ -1167,7 +1166,11 @@ struct DefaultConv2dFprop < WarpMmaTensorOp, kPartitionsK, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 4 >::Epilogue; // Define the kernel @@ -1628,7 +1631,11 @@ struct DefaultConv2dFprop < ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 4 >::Epilogue; // Define the kernel @@ -1741,7 +1748,11 @@ struct DefaultConv2dFprop < ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 4 >::Epilogue; // Define the kernel @@ -1751,7 +1762,6 @@ struct DefaultConv2dFprop < ThreadblockSwizzle, conv::Operator::kFprop >; - }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -1853,7 +1863,11 @@ struct DefaultConv2dFprop < ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 4 >::Epilogue; // Define the kernel @@ -1967,7 +1981,11 @@ struct DefaultConv2dFprop < ThreadblockShape, WarpMmaSimtOp, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 4 >::Epilogue; // Define the kernel diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h b/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h index 4fc2200a5d..8589ace029 100644 --- a/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h +++ b/include/cutlass/conv/kernel/default_conv2d_fprop_fusion.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -45,8 +45,8 @@ #include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" #include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" #include "cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h" -#include "cutlass/conv/threadblock/regular_scale_bias_vector_access_iterator.h" -#include "cutlass/conv/warp/conv2d_fprop_scale_bias_iterator.h" +#include "cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h" +#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -76,7 +76,7 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided + conv::StrideSupport StrideSupport = StrideSupport::kUnity > struct DefaultConv2dFpropFusion; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -161,7 +161,7 @@ struct DefaultConv2dFpropFusion < LayoutScaleBias>; using SmemIteratorScaleBias = - cutlass::conv::threadblock::RegularScaleBiasVectorAccessIterator< + cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, LayoutScaleBias>; @@ -172,7 +172,7 @@ struct DefaultConv2dFpropFusion < static int const kThreadCount = 32; // Warp-level iterators to load scale and bias vectors - using WarpIteratorScaleBias = cutlass::conv::warp::WarpIteratorScaleBias< + using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator< MatrixShape, ElementScaleBias, LayoutScaleBias, MatrixShape, typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount, @@ -296,7 +296,7 @@ struct DefaultConv2dFpropFusion < LayoutScaleBias>; using SmemIteratorScaleBias = - cutlass::conv::threadblock::RegularScaleBiasVectorAccessIterator< + cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, LayoutScaleBias>; @@ -307,7 +307,7 @@ struct DefaultConv2dFpropFusion < static int const kThreadCount = 32; // Warp-level iterators to load scale and bias vectors - using WarpIteratorScaleBias = cutlass::conv::warp::WarpIteratorScaleBias< + using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator< MatrixShape, ElementScaleBias, LayoutScaleBias, MatrixShape, typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount, diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop_with_absmax.h b/include/cutlass/conv/kernel/default_conv2d_fprop_with_absmax.h new file mode 100644 index 0000000000..76bc12886c --- /dev/null +++ b/include/cutlass/conv/kernel/default_conv2d_fprop_with_absmax.h @@ -0,0 +1,127 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Defines a default configuration for convolution with absolute maximum calculation. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/conv/kernel/default_conv2d_fprop.h" +#include "cutlass/conv/kernel/implicit_gemm_convolution_with_absmax.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_with_absmax.h" +#include "cutlass/epilogue/threadblock/epilogue_with_absmax.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> +struct DefaultConv2dFpropWithAbsMax { + + using ImplicitGemmBase = typename DefaultConv2dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + StrideSupport, + AlignmentA, + AlignmentB + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::epilogue::threadblock::DefaultEpilogueWithAbsMax< + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ImplicitGemmBase::Epilogue::kPartitionsK, + ElementC, + typename EpilogueOutputOp::ElementAuxOutput, + ElementC, + EpilogueOutputOp, + 
ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithAbsMax< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h b/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h index 187c46f93a..0825789ced 100644 --- a/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h +++ b/include/cutlass/conv/kernel/default_conv2d_fprop_with_broadcast.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -31,7 +31,7 @@ /*! \file \brief - Defines a GEMM with Reduction based on an existing UniversalGemm kernel. + Defines a GEMM with Broadcast based on an existing UniversalGemm kernel. */ @@ -71,7 +71,7 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, /// Access granularity of A matrix in units of elements int AlignmentA = 128 / cutlass::sizeof_bits::value, /// Access granularity of B matrix in units of elements @@ -99,7 +99,7 @@ struct DefaultConv2dFpropWithBroadcast { AlignmentB >::Kernel; - // Replace epilogue + // Define epilogue using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastTensorOp< ArchTag, typename ImplicitGemmBase::Epilogue::Shape, @@ -107,7 +107,98 @@ struct DefaultConv2dFpropWithBroadcast { ImplicitGemmBase::Epilogue::kPartitionsK, ElementC, typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dFpropWithBroadcast < + 
ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + StrideSupport, + AlignmentA, + AlignmentB +> { + + using ImplicitGemmBase = typename DefaultConv2dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + StrideSupport, + AlignmentA, + AlignmentB + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastSimt< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, EpilogueOutputOp, ImplicitGemmBase::Epilogue::kElementsPerAccess >::Epilogue; diff --git a/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h b/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h index 82bbd252a1..e6e8a82209 100644 --- a/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h +++ b/include/cutlass/conv/kernel/default_conv2d_fprop_with_reduction.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -72,7 +72,7 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, /// Access granularity of A matrix in units of elements int AlignmentA = 128 / cutlass::sizeof_bits::value, /// Access granularity of B matrix in units of elements @@ -100,7 +100,7 @@ struct DefaultConv2dFpropWithReduction { AlignmentB >::Kernel; - // Replace epilogue + // Define epilogue using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithReductionTensorOp< ArchTag, typename ImplicitGemmBase::Epilogue::Shape, diff --git a/include/cutlass/conv/kernel/default_conv2d_group_fprop.h b/include/cutlass/conv/kernel/default_conv2d_group_fprop.h new file mode 100644 index 0000000000..e2deaf6fe2 --- /dev/null +++ b/include/cutlass/conv/kernel/default_conv2d_group_fprop.h @@ -0,0 +1,622 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. 
+ * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h" +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h" + +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv2dGroupFprop +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::GroupMode GroupMode, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> struct DefaultConv2dGroupFprop; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassTensorOp convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dGroupFprop 
specialization for Analytic IteratorAlgorithm and multistage +/// pipeline that supports all GroupMode. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::GroupMode GroupMode, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dGroupFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + GroupMode, + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB +> { + + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA, + AccessTypeA, + GroupMode + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + AccessTypeB, + GroupMode + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + CacheOpB, + MmaPolicy, + Stages + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv2dProblemSize, + GroupMode + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dGroupFprop specialization for Analytic IteratorAlgorithm and +/// 2 stage pipeline that supports all GroupMode. + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::GroupMode GroupMode, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dGroupFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + GroupMode, + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB +> { + + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA, + AccessTypeA, + GroupMode + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + AccessTypeB, + GroupMode + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = 
threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv2dProblemSize, + GroupMode + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dGroupFprop specialization for Optimized IteratorAlgorithm and multistage +/// pipeline that supports GroupMode::kSingleGroup. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dGroupFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + GroupMode::kSingleGroup, + IteratorAlgorithm::kOptimized, + StrideSupport, + AlignmentA, + AlignmentB +> { + + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA, + AccessTypeA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + AccessTypeB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + CacheOpB, + MmaPolicy, + Stages + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv2dProblemSize, + GroupMode::kSingleGroup + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv2dGroupFprop specialization for Optimized IteratorAlgorithm and +/// 2 stage pipeline that supports GroupMode::kSingleGroup. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv2dGroupFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + GroupMode::kSingleGroup, + IteratorAlgorithm::kOptimized, + StrideSupport, + AlignmentA, + AlignmentB +> { + + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + static_assert(platform::is_same::value, + "Current group conv only support NHWC layout"); + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using AccessTypeA = cutlass::AlignedArray; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA, + AccessTypeA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB, + AccessTypeB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = 
threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + static const int kPartitionsK = ThreadblockShape::kK / WarpShape::kK; + + // Define the epilogue + using Epilogue = typename detail::DefaultConvEpilogue< + ArchTag, + ThreadblockShape, + WarpMmaTensorOp, + kPartitionsK, + EpilogueOutputOp + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv2dProblemSize, + GroupMode::kSingleGroup + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/default_conv2d_wgrad.h b/include/cutlass/conv/kernel/default_conv2d_wgrad.h index c5e5b3c5b5..d0e52dfe34 100644 --- a/include/cutlass/conv/kernel/default_conv2d_wgrad.h +++ b/include/cutlass/conv/kernel/default_conv2d_wgrad.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/include/cutlass/conv/kernel/default_conv2d_wgrad_fusion.h b/include/cutlass/conv/kernel/default_conv2d_wgrad_fusion.h index e43e02cde5..110e07db9c 100644 --- a/include/cutlass/conv/kernel/default_conv2d_wgrad_fusion.h +++ b/include/cutlass/conv/kernel/default_conv2d_wgrad_fusion.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/include/cutlass/conv/kernel/default_conv3d_dgrad.h b/include/cutlass/conv/kernel/default_conv3d_dgrad.h index 8f408ef04e..cb50ba49b8 100644 --- a/include/cutlass/conv/kernel/default_conv3d_dgrad.h +++ b/include/cutlass/conv/kernel/default_conv3d_dgrad.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -293,6 +293,439 @@ struct DefaultConv3dDgrad < }; +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kStrided +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dDgradFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad, + Conv3dProblemSize + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dDgrad specialization for Optimized IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dDgrad < + ElementA, + LayoutA, + ElementB, + 
LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dDgradFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB + // ThreadMapB, + // StrideSupport::kUnity + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultConv3dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kStrided +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + // cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + // > + >; + + using 
SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + // cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv3dDgradFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + // > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dDgrad specialization for Optimized IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultConv3dDgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + // cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity + // > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + // cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dDgradFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB + // ThreadMapB, + // StrideSupport::kUnity + // > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // 
Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDgrad, + Conv3dProblemSize + >; + +}; + ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace kernel diff --git a/include/cutlass/conv/kernel/default_conv3d_fprop.h b/include/cutlass/conv/kernel/default_conv3d_fprop.h index 7b20ae344a..41fdd64a5e 100644 --- a/include/cutlass/conv/kernel/default_conv3d_fprop.h +++ b/include/cutlass/conv/kernel/default_conv3d_fprop.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -54,7 +54,7 @@ namespace conv { namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Conv2dFprop +/// Defines a kernel for Conv3dFprop template < typename ElementA, typename LayoutA, @@ -73,7 +73,7 @@ template < int Stages, typename MathOperatorTag, conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, - conv::StrideSupport StrideSupport = StrideSupport::kStrided + conv::StrideSupport StrideSupport = StrideSupport::kUnity > struct DefaultConv3dFprop; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -94,7 +94,8 @@ template < typename InstructionShape, typename EpilogueOutputOp, typename ThreadblockSwizzle, - typename MathOperatorTag + typename MathOperatorTag, + conv::StrideSupport StrideSupport > struct DefaultConv3dFprop < ElementA, @@ -113,7 +114,8 @@ struct DefaultConv3dFprop < ThreadblockSwizzle, 2, MathOperatorTag, - IteratorAlgorithm::kAnalytic + IteratorAlgorithm::kAnalytic, + StrideSupport > { // Define the core components from GEMM @@ -185,7 +187,7 @@ struct DefaultConv3dFprop < ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Conv2dFprop specialization for Analytic IteratorAlgorithm and multistage +/// Defines a kernel for Conv3dFprop specialization for Analytic IteratorAlgorithm and multistage // pipeline. 
template < typename ElementA, @@ -202,7 +204,8 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, int Stages, - typename MathOperatorTag + typename MathOperatorTag, + conv::StrideSupport StrideSupport > struct DefaultConv3dFprop < ElementA, @@ -221,7 +224,8 @@ struct DefaultConv3dFprop < ThreadblockSwizzle, Stages, MathOperatorTag, - IteratorAlgorithm::kAnalytic + IteratorAlgorithm::kAnalytic, + StrideSupport > { // Define the core components from GEMM @@ -306,7 +310,8 @@ template < typename InstructionShape, typename EpilogueOutputOp, typename ThreadblockSwizzle, - typename MathOperatorTag + typename MathOperatorTag, + conv::StrideSupport StrideSupport > struct DefaultConv3dFprop < ElementA, @@ -325,7 +330,8 @@ struct DefaultConv3dFprop < ThreadblockSwizzle, 2, MathOperatorTag, - IteratorAlgorithm::kOptimized + IteratorAlgorithm::kOptimized, + StrideSupport > { // Define the core components from GEMM @@ -416,7 +422,8 @@ template < typename EpilogueOutputOp, typename ThreadblockSwizzle, int Stages, - typename MathOperatorTag + typename MathOperatorTag, + conv::StrideSupport StrideSupport > struct DefaultConv3dFprop < ElementA, @@ -435,7 +442,8 @@ struct DefaultConv3dFprop < ThreadblockSwizzle, Stages, MathOperatorTag, - IteratorAlgorithm::kOptimized + IteratorAlgorithm::kOptimized, + StrideSupport > { // Define the core components from GEMM @@ -492,7 +500,465 @@ struct DefaultConv3dFprop < WarpMmaTensorOp, 1, EpilogueOutputOp, - EpilogueOutputOp::kCount + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 5 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv3dFprop specialization for Analytic IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport +> +struct DefaultConv3dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over 
tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 5 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dFprop specialization for Optimized IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport +> +struct DefaultConv3dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename 
epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 5 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dFprop specialization for Analytic IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport +> +struct DefaultConv3dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + StrideSupport +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 5 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dFprop specialization for Optimized IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename 
ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport +> +struct DefaultConv3dFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport, + 5 >::Epilogue; // Define the kernel diff --git a/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h b/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h new file mode 100644 index 0000000000..d0457d572e --- /dev/null +++ b/include/cutlass/conv/kernel/default_conv3d_fprop_fusion.h @@ -0,0 +1,360 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level fused activation's scale+bias+relu and implicit GEMM convolution + definitions that combine threadblock-scoped matrix multiply-add with the + appropriate threadblock-scoped epilogue. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h" +#include "cutlass/transform/threadblock/regular_scale_bias_vector_access_iterator.h" +#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for fused batch norm and Conv3dFprop +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementScaleBias, + typename LayoutScaleBias, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kUnity +> struct DefaultConv3dFpropFusion; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassTensorOp convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dFprop specialzation for Analytic IteratorAlgorithm and multistage +/// pipeline. 
+template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementScaleBias, + typename LayoutScaleBias, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dFpropFusion < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementScaleBias, + LayoutScaleBias, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using IteratorScaleBias = + cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator< + cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, + LayoutScaleBias>; + + using SmemIteratorScaleBias = + cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< + cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, + LayoutScaleBias>; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static int const kThreadCount = 32; + + // Warp-level iterators to load scale and bias vectors + using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator< + MatrixShape, ElementScaleBias, + LayoutScaleBias, MatrixShape, + typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount, + MmaCore::WarpCount::kK>; + + // Define the Mma + using Mma = threadblock::ImplicitGemmFpropFusionMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Global, + IteratorScaleBias, + SmemIteratorScaleBias, + arch::CacheOperation::Always, + MmaPolicy, + WarpIteratorScaleBias, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for 
Conv3dFprop specialzation for Optimzed IteratorAlgorithm and +/// multistage pipeline. +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementScaleBias, + typename LayoutScaleBias, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dFpropFusion < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementScaleBias, + LayoutScaleBias, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassTensorOp, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassTensorOp, + Stages, MathOperatorTag + >; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dFpropActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + LayoutA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + /// Define iterators over tiles from scale/bias vectors + using IteratorScaleBias = + cutlass::conv::threadblock::PredicatedScaleBiasVectorAccessIterator< + cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, + LayoutScaleBias>; + + using SmemIteratorScaleBias = + cutlass::transform::threadblock::RegularScaleBiasVectorAccessIterator< + cutlass::MatrixShape<1, ThreadblockShape::kK>, ElementScaleBias, + LayoutScaleBias>; + + // Warp-level GEMM components + using WarpMmaTensorOp = typename MmaCore::MmaTensorOp; + using MmaPolicy = typename MmaCore::MmaPolicy; + + static int const kThreadCount = 32; + + // Warp-level iterators to load scale and bias vectors + using WarpIteratorScaleBias = cutlass::gemm::warp::ScaleBiasTileIterator< + MatrixShape, ElementScaleBias, + LayoutScaleBias, MatrixShape, + typename WarpMmaTensorOp::IteratorA::Base::Policy, kThreadCount, + MmaCore::WarpCount::kK>; + + // Define the Mma + using Mma = threadblock::ImplicitGemmFpropFusionMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Global, + IteratorScaleBias, + SmemIteratorScaleBias, + arch::CacheOperation::Always, + MmaPolicy, + WarpIteratorScaleBias, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueTensorOp< + ThreadblockShape, + WarpMmaTensorOp, + 1, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionFusion< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; +}; + 
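The fused scale+bias+ReLU Conv3dFprop defaults above follow the usual CUTLASS 2.x pattern: `DefaultConv3dFpropFusion` stitches together the mainloop, the scale/bias iterators, and the epilogue, and publishes the assembled type as `::Kernel`. As a rough illustration only, the sketch below shows one plausible way such a kernel could be composed and handed to the device-level fusion runner; every concrete choice here (half-precision NDHWC tensors, SM80 tensor-op tile shapes, a plain linear-combination epilogue, three stages, the `ImplicitGemmConvolutionFusion` wrapper) is an assumption made for the example, not something prescribed by this header.

#include "cutlass/conv/kernel/default_conv3d_fprop_fusion.h"
#include "cutlass/conv/device/implicit_gemm_convolution_fusion.h"
#include "cutlass/epilogue/thread/linear_combination.h"
#include "cutlass/gemm/threadblock/threadblock_swizzle.h"

// Hypothetical composition of the fused Conv3dFprop kernel (illustrative types and tile shapes).
using FusedConv3dFpropKernel = cutlass::conv::kernel::DefaultConv3dFpropFusion<
  cutlass::half_t, cutlass::layout::TensorNDHWC,    // ElementA / LayoutA (activations)
  cutlass::half_t, cutlass::layout::TensorNDHWC,    // ElementB / LayoutB (filters)
  cutlass::half_t, cutlass::layout::RowMajor,       // ElementScaleBias / LayoutScaleBias
  cutlass::half_t, cutlass::layout::TensorNDHWC,    // ElementC / LayoutC (output)
  float,                                            // ElementAccumulator
  cutlass::arch::OpClassTensorOp,
  cutlass::arch::Sm80,
  cutlass::gemm::GemmShape<128, 128, 32>,           // ThreadblockShape
  cutlass::gemm::GemmShape<64, 64, 32>,             // WarpShape
  cutlass::gemm::GemmShape<16, 8, 16>,              // InstructionShape
  cutlass::epilogue::thread::LinearCombination<
      cutlass::half_t, 8, float, float>,            // EpilogueOutputOp
  cutlass::gemm::threadblock::GemmIdentityThreadblockSwizzle<>,
  3,                                                // Stages
  cutlass::arch::OpMultiplyAdd,                     // MathOperatorTag
  cutlass::conv::IteratorAlgorithm::kOptimized      // selects the optimized specialization above
>::Kernel;

// Assumed device-level wrapper for the fused kernel, mirroring how the 2-D fusion kernels are run.
using FusedConv3dFprop =
    cutlass::conv::device::ImplicitGemmConvolutionFusion<FusedConv3dFpropKernel>;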
+///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/default_conv3d_fprop_with_broadcast.h b/include/cutlass/conv/kernel/default_conv3d_fprop_with_broadcast.h new file mode 100644 index 0000000000..0fc291e605 --- /dev/null +++ b/include/cutlass/conv/kernel/default_conv3d_fprop_with_broadcast.h @@ -0,0 +1,222 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Defines a GEMM with Broadcast based on an existing UniversalGemm kernel. 
+ +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/conv/kernel/default_conv3d_fprop.h" +#include "cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h" +#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> +struct DefaultConv3dFpropWithBroadcast { + + using ImplicitGemmBase = typename DefaultConv3dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + StrideSupport + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastTensorOp< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ImplicitGemmBase::Epilogue::kPartitionsK, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv3dFprop specialization for Analytic IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultConv3dFpropWithBroadcast < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + 
ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + StrideSupport, + AlignmentA, + AlignmentB +> { + + using ImplicitGemmBase = typename DefaultConv3dFprop< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + StrideSupport + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastSimt< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess, + layout::NoPermute, + StrideSupport, + 5 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/default_conv3d_wgrad.h b/include/cutlass/conv/kernel/default_conv3d_wgrad.h index 46dc392043..4ed5e0c1bf 100644 --- a/include/cutlass/conv/kernel/default_conv3d_wgrad.h +++ b/include/cutlass/conv/kernel/default_conv3d_wgrad.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -53,7 +53,7 @@ namespace kernel { ///////////////////////////////////////////////////////////////////////////////////////////////// -/// Defines a kernel for Conv2dWgrad +/// Defines a kernel for Conv3dWgrad template < typename ElementA, typename LayoutA, @@ -500,6 +500,433 @@ struct DefaultConv3dWgrad < Conv3dProblemSize >; }; + + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Conv3dWgrad specialization for Analytic IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad, + Conv3dProblemSize + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dWgrad specialization for Optimized IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename 
ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultConv3dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad, + Conv3dProblemSize + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dWgrad specialization for Analytic IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultConv3dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, 
MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Conv3dWgrad specialization for Optimized IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultConv3dWgrad < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::ColumnMajor, + ElementB, layout::RowMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dWgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dWgradActivationTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + ThreadMapB + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy 
= typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kWgrad, + Conv3dProblemSize + >; + +}; ///////////////////////////////////////////////////////////////////////////////////////////////// } // namespace kernel diff --git a/include/cutlass/conv/kernel/default_deconv2d.h b/include/cutlass/conv/kernel/default_deconv2d.h new file mode 100644 index 0000000000..4db152cd7a --- /dev/null +++ b/include/cutlass/conv/kernel/default_deconv2d.h @@ -0,0 +1,999 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. 
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Deconv2d +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> struct DefaultDeconv2d; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Deconv2d specialization for Analytic IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kUnity + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + 
using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + cutlass::AlignedArray, + conv::GroupMode::kNone, + true /*IsDeconv*/ + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport::kStrided, + 4 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + cutlass::AlignedArray, + conv::GroupMode::kNone, + true /*IsDeconv*/ + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< + 
ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Deconv2d specialization for Optimized IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + cutlass::AlignedArray, + true /*IsDeconv*/ + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport::kStrided, + 4 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename 
ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + cutlass::AlignedArray, + true /*IsDeconv*/ + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; + +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Deconv2d specialization for Analytic IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, 
arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kUnity + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + cutlass::AlignedArray, + conv::GroupMode::kNone, + true /*IsDeconv*/ + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport::kStrided, + 4 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; + +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + 
ThreadMapB, + cutlass::AlignedArray, + conv::GroupMode::kNone, + true /*IsDeconv*/ + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Deconv2d specialization for Optimized IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + cutlass::AlignedArray, + true /*IsDeconv*/ + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + 
StrideSupport::kStrided, + 4 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + cutlass::AlignedArray, + true /*IsDeconv*/ + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimtStridedDgrad< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionStridedDgrad< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; + +}; + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/default_deconv2d_with_broadcast.h b/include/cutlass/conv/kernel/default_deconv2d_with_broadcast.h new file mode 100644 index 0000000000..d11432ed39 --- /dev/null +++ b/include/cutlass/conv/kernel/default_deconv2d_with_broadcast.h @@ -0,0 +1,305 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA 
CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Defines a GEMM with Broadcast based on an existing UniversalGemm kernel. 
+ +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/conv/kernel/default_deconv2d.h" +#include "cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h" +#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> +struct DefaultDeconv2dWithBroadcast { + + using ImplicitGemmBase = typename DefaultDeconv2d< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + StrideSupport, + AlignmentA, + AlignmentB + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastTensorOp< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ImplicitGemmBase::Epilogue::kPartitionsK, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Deconv2d specialization, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2dWithBroadcast < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + 
conv::StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + using ImplicitGemmBase = typename DefaultDeconv2d< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kUnity, + AlignmentA, + AlignmentB + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastSimt< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv2dWithBroadcast < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + using ImplicitGemmBase = typename DefaultDeconv2d< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastSimtStridedDgrad< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/default_deconv3d.h b/include/cutlass/conv/kernel/default_deconv3d.h new file mode 100644 index 0000000000..70800c7af7 --- /dev/null +++ b/include/cutlass/conv/kernel/default_deconv3d.h @@ -0,0 +1,541 @@ 
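Note: the `Default*WithBroadcast` traits above (and the 3-D variants later in this patch) all follow the same composition recipe: instantiate the plain default (`DefaultDeconv2d` / `DefaultDeconv3d`) to obtain `ImplicitGemmBase`, keep its mainloop untouched, substitute a broadcast-capable epilogue derived from the base epilogue's compile-time configuration, and re-assemble the kernel. The following is a minimal, self-contained sketch of that recipe only; `BaseDefault`, `BroadcastEpilogue` and `KernelFromParts` are illustrative stand-ins, not CUTLASS types.

```cpp
#include <type_traits>

// Pretend "base default": picks a mainloop and a plain epilogue for an element type.
template <typename Element>
struct BaseDefault {
  struct Mma      { using ElementT = Element; };
  struct Epilogue { using ElementT = Element; static constexpr int kElementsPerAccess = 4; };
  struct Kernel   { using MmaT = Mma; using EpilogueT = Epilogue; };
};

// Broadcast epilogue derived from the base epilogue's compile-time configuration.
template <typename BaseEpilogue, typename ElementVector>
struct BroadcastEpilogue {
  static constexpr int kElementsPerAccess = BaseEpilogue::kElementsPerAccess;
  using ElementVectorT = ElementVector;
};

// Re-assembled kernel: base mainloop + replacement epilogue.
template <typename Mma, typename Epilogue>
struct KernelFromParts { using MmaT = Mma; using EpilogueT = Epilogue; };

// Mirrors the shape of DefaultDeconv2dWithBroadcast: derive the base, swap the epilogue.
template <typename Element, typename ElementVector>
struct DefaultWithBroadcast {
  using Base     = BaseDefault<Element>;
  using Epilogue = BroadcastEpilogue<typename Base::Epilogue, ElementVector>;
  using Kernel   = KernelFromParts<typename Base::Mma, Epilogue>;
};

int main() {
  using K = DefaultWithBroadcast<float, float>::Kernel;
  static_assert(std::is_same<K::MmaT, BaseDefault<float>::Mma>::value,
                "the mainloop is reused unchanged from the base default");
  return 0;
}
```

The point of the pattern is that the broadcast variant never re-derives tile shapes or iterators; it only rewires the epilogue.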
+/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. 
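For orientation, the deconv defaults in this header treat the operation as a dgrad-style implicit GEMM: the output-gradient and filter tensors are presented to the mainloop as the A and B operands of a GEMM whose M/N/K extents are derived from the convolution problem size. The sketch below shows the commonly cited 2-D mapping (GEMM_M = N*H*W, GEMM_N = C, GEMM_K = K*R*S); the 3-D case additionally folds the depth extent into M and the filter depth into K. Treat this as an illustration of the mapping, not a quotation of `cutlass::conv::implicit_gemm_problem_size`.

```cpp
#include <cstdio>

struct Conv2dSize { int N, H, W, C, K, R, S; };
struct GemmSize   { long long m, n, k; };

// Dgrad/deconv-style implicit GEMM extents (2-D case).
GemmSize dgrad_implicit_gemm(Conv2dSize const& c) {
  return { (long long)c.N * c.H * c.W,    // one GEMM row per input-gradient element
           (long long)c.C,                // one GEMM column per input channel
           (long long)c.K * c.R * c.S };  // reduction over output channels x filter taps
}

int main() {
  Conv2dSize c{/*N=*/8, /*H=*/56, /*W=*/56, /*C=*/64, /*K=*/128, /*R=*/3, /*S=*/3};
  GemmSize g = dgrad_implicit_gemm(c);
  std::printf("implicit GEMM extents: %lld x %lld x %lld\n", g.m, g.n, g.k);
  return 0;
}
```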
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" + +#include "cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h" +#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h" + +#include "cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Deconv3d +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided +> struct DefaultDeconv3d; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultDeconv3d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kStrided +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + true /*IsDeconv*/ + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = 
threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport::kStrided, + 5 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv, + Conv3dProblemSize + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Deconv3d specialization for Optimized IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag +> +struct DefaultDeconv3d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + Stages, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB, + true /*IsDeconv*/ + // ThreadMapB, + // StrideSupport::kUnity + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmMultistage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + arch::CacheOperation::Always, + IteratorB, + SmemIteratorB, + arch::CacheOperation::Always, + MmaPolicy, + Stages + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport::kStrided, + 5 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv, + Conv3dProblemSize + >; +}; + 
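The `DefaultDeconv3d` specializations above are selected entirely at compile time by the `IteratorAlgorithm`, `StrideSupport` and stage-count arguments: analytic iterators pair with the strided path, optimized iterators with the unit-stride fast path, and combinations that are not specialized simply do not exist. Below is a stripped-down sketch of that dispatch style using stand-in names rather than the CUTLASS enums and traits.

```cpp
#include <cstdio>

enum class IteratorAlgorithm { kAnalytic, kOptimized };
enum class StrideSupport     { kStrided, kUnity };

// Primary template: declared but not defined, so unsupported combinations cannot compile.
template <IteratorAlgorithm Algo, StrideSupport Stride>
struct DefaultKernel;

template <>
struct DefaultKernel<IteratorAlgorithm::kAnalytic, StrideSupport::kStrided> {
  static const char* name() { return "analytic iterators, strided mapping"; }
};

template <>
struct DefaultKernel<IteratorAlgorithm::kOptimized, StrideSupport::kUnity> {
  static const char* name() { return "optimized iterators, unit-stride fast path"; }
};

int main() {
  std::printf("%s\n", DefaultKernel<IteratorAlgorithm::kAnalytic, StrideSupport::kStrided>::name());
  std::printf("%s\n", DefaultKernel<IteratorAlgorithm::kOptimized, StrideSupport::kUnity>::name());
  // DefaultKernel<IteratorAlgorithm::kOptimized, StrideSupport::kStrided> is not defined
  // here, so naming its members would be a compile-time error, just as the real defaults
  // only exist for supported (algorithm, stride support) combinations.
  return 0;
}
```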
+///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultDeconv3d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, + IteratorAlgorithm::kAnalytic, + conv::StrideSupport::kStrided +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + // cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + conv::StrideSupport::kStrided + // > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + // cutlass::conv::threadblock::TileIteratorStridedDgrad< + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, + ThreadMapB, + true /*IsDeconv*/ + // > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport::kStrided, + 5 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Defines a kernel for Deconv3d specialization for Optimized IteratorAlgorithm, +/// 2 stage pipeline, and FFMA-based mainloop for SM50 +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag +> +struct DefaultDeconv3d < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, 
+ 2, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport::kUnity +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::gemm::threadblock::DefaultMmaCore< + ThreadblockShape, WarpShape, InstructionShape, ElementA, layout::RowMajor, + ElementB, layout::ColumnMajor, ElementAccumulator, layout::RowMajor, arch::OpClassSimt, + 2, MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + // cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dDgradOutputGradientTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementA, + ThreadMapA, + StrideSupport::kUnity + // > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using IteratorB = + // cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv3dFpropFilterTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, + LayoutB, + ThreadMapB, + true /*IsDeconv*/ + // ThreadMapB, + // StrideSupport::kUnity + // > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::ImplicitGemmPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + false, + layout::NoPermute, + StrideSupport::kStrided, + 5 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv, + Conv3dProblemSize + >; + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + diff --git a/include/cutlass/conv/kernel/default_deconv3d_with_broadcast.h b/include/cutlass/conv/kernel/default_deconv3d_with_broadcast.h new file mode 100644 index 0000000000..affe7a06f4 --- /dev/null +++ b/include/cutlass/conv/kernel/default_deconv3d_with_broadcast.h @@ -0,0 +1,309 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Defines a GEMM with Broadcast based on an existing UniversalGemm kernel. + +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/conv/kernel/default_deconv3d.h" +#include "cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h" + +#include "cutlass/epilogue/threadblock/default_epilogue_with_broadcast.h" +#include "cutlass/epilogue/threadblock/epilogue_with_broadcast.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kOptimized, + conv::StrideSupport StrideSupport = StrideSupport::kStrided, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> +struct DefaultDeconv3dWithBroadcast { + + using ImplicitGemmBase = typename DefaultDeconv3d< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + OperatorClass, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + StrideSupport + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastTensorOp< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ImplicitGemmBase::Epilogue::kPartitionsK, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt 
convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Deconv3d specialization for Analytic IteratorAlgorithm, +/// multi-stage pipeline, and FFMA-based mainloop for SM80 + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv3dWithBroadcast < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kUnity, + AlignmentA, + AlignmentB +> { + + using ImplicitGemmBase = typename DefaultDeconv3d< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kUnity + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastSimt< + ArchTag, + typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess, + layout::NoPermute, + StrideSupport::kStrided, + 5 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm, + int AlignmentA, + int AlignmentB +> +struct DefaultDeconv3dWithBroadcast < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kStrided, + AlignmentA, + AlignmentB +> { + + using ImplicitGemmBase = typename DefaultDeconv3d< + ElementA, LayoutA, + ElementB, LayoutB, + ElementC, LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm, + conv::StrideSupport::kStrided + >::Kernel; + + // Define epilogue + using Epilogue = typename cutlass::conv::kernel::detail::DefaultConvEpilogueWithBroadcastSimt< + ArchTag, + 
typename ImplicitGemmBase::Epilogue::Shape, + typename ImplicitGemmBase::Epilogue::WarpMmaOperator, + ElementC, + typename EpilogueOutputOp::ElementT, + typename EpilogueOutputOp::ElementVector, + EpilogueOutputOp, + ImplicitGemmBase::Epilogue::kElementsPerAccess, + layout::NoPermute, + StrideSupport::kStrided, + 5 + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolutionWithFusedEpilogue< + typename ImplicitGemmBase::Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kDeconv, + Conv3dProblemSize + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/default_depthwise_fprop.h b/include/cutlass/conv/kernel/default_depthwise_fprop.h new file mode 100644 index 0000000000..aa4f2c359c --- /dev/null +++ b/include/cutlass/conv/kernel/default_depthwise_fprop.h @@ -0,0 +1,588 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief + Default kernel-level Depthwise implicit GEMM convolution definitions combine threadblock-scoped + matrix multiply-add with the appropriate threadblock-scoped epilogue. 
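The depthwise fprop defaults introduced in this header set `GroupMode::kDepthwise`, i.e. groups equals channels, so each output channel is produced from exactly one input channel and there is no cross-channel reduction. As a reference for that semantic only (it says nothing about the tiling or iterators defined here), the following is a plain CPU implementation, assuming NHWC activations and an R x S x C filter layout for illustration:

```cpp
#include <cstddef>
#include <cstdio>
#include <vector>

void depthwise_conv2d_nhwc(const std::vector<float>& x, int H, int W, int C,
                           const std::vector<float>& w, int R, int S,
                           std::vector<float>& y, int stride, int pad) {
  int P = (H + 2 * pad - R) / stride + 1;
  int Q = (W + 2 * pad - S) / stride + 1;
  y.assign(std::size_t(P) * Q * C, 0.f);
  for (int p = 0; p < P; ++p)
    for (int q = 0; q < Q; ++q)
      for (int c = 0; c < C; ++c) {          // one filter per channel; no cross-channel sum
        float acc = 0.f;
        for (int r = 0; r < R; ++r)
          for (int s = 0; s < S; ++s) {
            int h = p * stride - pad + r;
            int wcol = q * stride - pad + s;
            if (h >= 0 && h < H && wcol >= 0 && wcol < W)
              acc += x[(std::size_t(h) * W + wcol) * C + c] * w[(std::size_t(r) * S + s) * C + c];
          }
        y[(std::size_t(p) * Q + q) * C + c] = acc;
      }
}

int main() {
  std::vector<float> x(5 * 5 * 2, 1.f), w(3 * 3 * 2, 1.f), y;
  depthwise_conv2d_nhwc(x, 5, 5, 2, w, 3, 3, y, /*stride=*/1, /*pad=*/1);
  std::printf("y[0] = %.0f\n", y[0]);  // corner output: only 2x2 taps are in bounds -> 4
  return 0;
}
```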
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/conv/kernel/default_conv2d.h" +#include "cutlass/conv/kernel/direct_convolution.h" + +#include "cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h" + +#include "cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" +#include "cutlass/conv/threadblock/depthwise_fprop_pipelined.h" + +// Direct Conv Related Header files +#include "cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h" +#include "cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h" + +#include "cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h" +#include "cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for DepthwiseFprop +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = cutlass::sizeof_bits::value / cutlass::sizeof_bits::value +> struct DefaultDepthwiseFprop; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for DepthwiseFprop with direct convolution algorithm +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename OperatorClass, + typename ArchTag, + typename ThreadblockShape, + typename ThreadBlockOutputShape, + typename FilterShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, + conv::StrideSupport StrideSupport = StrideSupport::kUnity, + // MatrixShape + typename StrideShape = cutlass::MatrixShape<-1, -1>, + // MatrixShape< Height, Width> + typename DilationShape = cutlass::MatrixShape<-1, -1>, + /// Access granularity of A matrix in units of elements + int AlignmentA = 128 / cutlass::sizeof_bits::value, + /// Access granularity of B matrix in units of elements + int AlignmentB = 128 / cutlass::sizeof_bits::value +> struct DefaultDepthwiseDirect2dConvFprop; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// OpClassSimt convolutions +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Depthwise specialization for Analytic 
IteratorAlgorithm +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + int AlignmentA, + int AlignmentB +> +struct DefaultDepthwiseFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + 2, + MathOperatorTag, // cutlass::arch::OpMultiplyAdd + IteratorAlgorithm::kAnalytic, + StrideSupport, + AlignmentA, + AlignmentB +> { + + // Define the core components from GEMM + using MmaCore = typename cutlass::conv::threadblock::DepthwiseMmaCoreWithLaneAccessSize< + ThreadblockShape, + WarpShape, + InstructionShape, + ElementA, + layout::RowMajor, + ElementB, + layout::ColumnMajor, + ElementAccumulator, + layout::RowMajor, + arch::OpClassSimt, + 128, + sizeof_bits::value, + 2, + MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropActivationTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementA, LayoutA, + ThreadMapA + > + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::TileIterator< + cutlass::conv::threadblock::Conv2dFpropFilterTileAccessIteratorAnalytic< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB, + AccessTypeB, + cutlass::conv::GroupMode::kDepthwise + > + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + + // Define the Mma + using Mma = threadblock::DepthwiseFpropPipelined< + ThreadblockShape, + IteratorA, + SmemIteratorA, + IteratorB, + SmemIteratorB, + ElementC, + LayoutC, + MmaPolicy + >; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultEpilogueSimt< + ThreadblockShape, + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount + >::Epilogue; + + // Define the kernel + using Kernel = cutlass::conv::kernel::ImplicitGemmConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv2dProblemSize, + cutlass::conv::GroupMode::kDepthwise + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Depthwise specialization for direct 2d conv implementation, +/// multiple stage pipeline, and SIMT-based mainloop +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename ThreadBlockOutputShape, + typename FilterShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + typename StrideShape, + 
typename DilationShape, + int AlignmentA, + int AlignmentB +> +struct DefaultDepthwiseDirect2dConvFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + ThreadBlockOutputShape, + FilterShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kOptimized, + StrideSupport, + StrideShape, + DilationShape, + AlignmentA, + AlignmentB +> { + // One warp handles the entrie groups per cta. + static_assert(ThreadblockShape::kN == WarpShape::kN, + "ThreadblockShape::kN should be same as WarpShape::kN "); + static_assert(ThreadblockShape::kK == FilterShape::kCount && WarpShape::kK == FilterShape::kCount, + "ThreadblockShape::kK and WarpShape::kK should be same as filter size"); + static_assert(ThreadblockShape::kM % WarpShape::kM == 0, + "ThreadblockShape::kM must be divisible by WarpShape shape::kM"); + static_assert(ThreadBlockOutputShape::kN, "ThreadBlockOutputShape::kN should be 1"); + + // Define the core components from GEMM + using MmaCore = typename cutlass::conv::threadblock::DepthwiseDirectConvMmaCoreWithLaneAccessSize< + ThreadblockShape, + ThreadBlockOutputShape, + FilterShape, + WarpShape, + InstructionShape, + ElementA, + layout::RowMajor, + ElementB, + layout::ColumnMajor, + ElementAccumulator, + layout::RowMajor, + arch::OpClassSimt, + 128, + 128, + Stages, + MathOperatorTag>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized< + cutlass::MatrixShape, // < outputShape:KMNK, groups per cta> + ThreadBlockOutputShape, + ElementA, LayoutA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + using ThreadOutputShape = typename MmaCore::ThreadOutputShape; + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * AlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultDirectConvEpilogueSimt< + ThreadblockShape, // < outputShape:KMNK, groups per cta> + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + ThreadOutputShape, + ThreadBlockOutputShape + >::Epilogue; + + // Define the Mma + using Mma = threadblock::DepthwiseFpropDirectConvMultipleStage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + CacheOpA, + IteratorB, + SmemIteratorB, + CacheOpB, + MmaPolicy, + Stages, + Epilogue + >; + + // Define the kernel + using Kernel = cutlass::conv::kernel::DirectConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv2dProblemSize, + cutlass::conv::GroupMode::kDepthwise, + ThreadBlockOutputShape + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Defines a kernel for Depthwise specialization for direct 2d conv implementation, +/// multiple stage pipeline, and SIMT-based mainloop +template < + typename ElementA, + typename LayoutA, + typename ElementB, + typename LayoutB, + typename ElementC, + typename LayoutC, + typename ElementAccumulator, + typename ArchTag, + typename ThreadblockShape, + typename ThreadBlockOutputShape, + typename FilterShape, + typename WarpShape, + typename InstructionShape, + typename EpilogueOutputOp, + typename ThreadblockSwizzle, + int Stages, + typename MathOperatorTag, + conv::StrideSupport StrideSupport, + typename StrideShape, + typename DilationShape, + int AlignmentA, + int AlignmentB +> +struct DefaultDepthwiseDirect2dConvFprop < + ElementA, + LayoutA, + ElementB, + LayoutB, + ElementC, + LayoutC, + ElementAccumulator, + arch::OpClassSimt, + ArchTag, + ThreadblockShape, + ThreadBlockOutputShape, + FilterShape, + WarpShape, + InstructionShape, + EpilogueOutputOp, + ThreadblockSwizzle, + Stages, + MathOperatorTag, + IteratorAlgorithm::kFixedStrideDilation, + StrideSupport, + StrideShape, + DilationShape, + AlignmentA, + AlignmentB +> { + + + + // One warp handles the entrie groups per cta. 
+ static_assert(ThreadblockShape::kN == WarpShape::kN, + "ThreadblockShape::kN should be same as WarpShape::kN "); + static_assert(ThreadblockShape::kK == FilterShape::kCount && WarpShape::kK == FilterShape::kCount, + "ThreadblockShape::kK and WarpShape::kK should be same as filter size"); + static_assert(ThreadblockShape::kM % WarpShape::kM == 0, + "ThreadblockShape::kM must be divisible by WarpShape shape::kM"); + static_assert(ThreadBlockOutputShape::kN, "ThreadBlockOutputShape::kN should be 1"); + + static_assert(StrideShape::kRow >= 0 && StrideShape::kColumn >= 0, "Stride should be fixed"); + static_assert(DilationShape::kRow >= 0 && DilationShape::kColumn >= 0, "Stride should be fixed"); + + // Activations loaded by threadblock + static int const ActivationShapeH = (ThreadBlockOutputShape::kH - 1) * StrideShape::kRow + + (FilterShape::kRow - 1) * DilationShape::kRow + 1; + + static int const ActivationShapeW = (ThreadBlockOutputShape::kW - 1) * StrideShape::kColumn + + (FilterShape::kColumn - 1) * DilationShape::kColumn + 1; + + using ActivationShape = + cutlass::conv::TensorNHWCShape<1, ActivationShapeH, ActivationShapeW, ThreadblockShape::kN >; + + // Define the core components from GEMM + using MmaCore = typename cutlass::conv::threadblock::DepthwiseDirectConvMmaCoreWithLaneAccessSize< + ThreadblockShape, + ThreadBlockOutputShape, + FilterShape, + WarpShape, + InstructionShape, + ElementA, + layout::RowMajor, + ElementB, + layout::ColumnMajor, + ElementAccumulator, + layout::RowMajor, + arch::OpClassSimt, + 128, + 128, + Stages, + MathOperatorTag, + IteratorAlgorithm::kFixedStrideDilation, + StrideShape, + DilationShape, + ActivationShape>; + + // Define iterators over tiles from the A operand + using ThreadMapA = typename MmaCore::IteratorThreadMapA; + using IteratorA = + cutlass::conv::threadblock::DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation< + cutlass::MatrixShape, // < outputShape:KMNK, groups per cta> + ThreadBlockOutputShape, + StrideShape, + DilationShape, + ActivationShape, + ElementA, LayoutA, + ThreadMapA + >; + + using SmemIteratorA = typename MmaCore::SmemIteratorA; + + // Define iterators over tiles from the B operand + using ThreadMapB = typename MmaCore::IteratorThreadMapB; + using AccessTypeB = cutlass::AlignedArray; + using IteratorB = + cutlass::conv::threadblock::DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized< + cutlass::MatrixShape, + ElementB, LayoutB, + ThreadMapB + >; + + using SmemIteratorB = typename MmaCore::SmemIteratorB; + + // Warp-level GEMM components + using WarpMmaSimtOp = typename MmaCore::MmaWarpSimt; + using MmaPolicy = typename MmaCore::MmaPolicy; + using ThreadOutputShape = typename MmaCore::ThreadOutputShape; + static cutlass::arch::CacheOperation::Kind const CacheOpA = + ((sizeof_bits::value * AlignmentA) == 128) + ? cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + static cutlass::arch::CacheOperation::Kind const CacheOpB = + ((sizeof_bits::value * AlignmentB) == 128) + ? 
cutlass::arch::CacheOperation::Global + : cutlass::arch::CacheOperation::Always; + + // Define the epilogue + using Epilogue = typename epilogue::threadblock::DefaultDirectConvEpilogueSimt< + ThreadblockShape, // < outputShape:KMNK, groups per cta> + WarpMmaSimtOp, + EpilogueOutputOp, + EpilogueOutputOp::kCount, + ThreadOutputShape, + ThreadBlockOutputShape + >::Epilogue; + + // Define the Mma + using Mma = threadblock::DepthwiseFpropDirectConvMultipleStage< + ThreadblockShape, + IteratorA, + SmemIteratorA, + CacheOpA, + IteratorB, + SmemIteratorB, + CacheOpB, + MmaPolicy, + Stages, + Epilogue, + IteratorAlgorithm::kFixedStrideDilation + >; + + // Define the kernel + using Kernel = cutlass::conv::kernel::DirectConvolution< + Mma, + Epilogue, + ThreadblockSwizzle, + conv::Operator::kFprop, + Conv2dProblemSize, + cutlass::conv::GroupMode::kDepthwise, + ThreadBlockOutputShape + >; +}; + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/direct_convolution.h b/include/cutlass/conv/kernel/direct_convolution.h new file mode 100644 index 0000000000..d4e98fa49e --- /dev/null +++ b/include/cutlass/conv/kernel/direct_convolution.h @@ -0,0 +1,506 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a multi-staged Depthwise Convolution kernel. 
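The `kFixedStrideDilation` specialization above sizes the per-threadblock activation tile at compile time from the output tile, filter, stride and dilation: `(out - 1) * stride + (filter - 1) * dilation + 1` in each spatial dimension. A small worked example of that arithmetic follows; the tile and filter sizes are arbitrary examples, not kernel defaults.

```cpp
#include <cstdio>

constexpr int activation_extent(int output_extent, int stride, int filter_extent, int dilation) {
  // Input span needed to produce `output_extent` outputs with the given stride,
  // filter size and dilation: (out - 1) * stride + (filter - 1) * dilation + 1.
  return (output_extent - 1) * stride + (filter_extent - 1) * dilation + 1;
}

int main() {
  // e.g. an 8x8 output tile, 3x3 filter, stride 2, dilation 1:
  constexpr int h = activation_extent(/*output=*/8, /*stride=*/2, /*filter=*/3, /*dilation=*/1);
  constexpr int w = activation_extent(8, 2, 3, 1);
  static_assert(h == 17 && w == 17, "each threadblock stages a 17x17 activation tile");
  std::printf("activation tile: %d x %d\n", h, w);
  return 0;
}
```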
+*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/platform/platform.h" +#include "cutlass/semaphore.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters structure +template > ///! OutputShape per ThreadBlock +struct DirectConvolutionParams { + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using ThreadBlockOutputShape = ThreadBlockOutputShape_; + static Operator const kConvolutionalOperator = ConvOperator; + using ConvProblemSize = ConvProblemSize_; + using Arguments = Arguments_; + using ConvOutputIteratorParameter = ConvOutputIteratorParameter_; + + using ThreadblockShape = typename Mma::Shape; + static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; + static conv::GroupMode const kGroupMode = GroupMode_; + static int const kStages = Mma::kStages; + + ConvProblemSize problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + gemm::GemmCoord implicit_gemm_problem_size; + int swizzle_log_tile; + int smem_size_; + + int gemm_k_iterations; + int gemm_k_iterations_per_channel; + typename Mma::IteratorA::Params iterator_A; + typename Mma::IteratorA::Element const *ptr_A; + typename Mma::IteratorB::Params iterator_B; + typename Mma::IteratorB::Element const *ptr_B; + typename Mma::IteratorB::Element *ptr_reordered_B; + typename Epilogue::OutputTileIterator::Params iterator_C; + typename Epilogue::OutputTileIterator::Element *ptr_C; + typename Epilogue::OutputTileIterator::Params iterator_D; + typename Epilogue::OutputTileIterator::Element *ptr_D; + typename EpilogueOutputOp::Params output_op; + int *semaphore; + SplitKMode split_k_mode; + int split_k_slices; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + DirectConvolutionParams() : swizzle_log_tile(0), gemm_k_iterations(0) {} + + /// + CUTLASS_HOST_DEVICE + DirectConvolutionParams(Arguments const &args, int *semaphore = nullptr) + : problem_size(args.problem_size), + implicit_gemm_problem_size( + cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)), + iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), + ptr_A(args.ref_A.data()), + iterator_B(Mma::IteratorB::getParams(args.problem_size, args.ref_B.layout())), + ptr_B(args.ref_B.data()), + ptr_reordered_B(args.ref_reordered_B.data()), + iterator_C(ConvOutputIteratorParameter::layout(args.ref_C), args.problem_size), + ptr_C(args.ref_C.data()), + iterator_D(ConvOutputIteratorParameter::layout(args.ref_D), args.problem_size), + ptr_D(args.ref_D.data()), + output_op(args.output_op), + semaphore(semaphore), + split_k_mode(args.split_k_mode), + split_k_slices(args.problem_size.split_k_slices) { + gemm_k_iterations = + depthwise_gemm_k_iterations(kConvolutionalOperator, + ThreadblockShape::kK, + args.problem_size, + 
kIteratorAlgorithm, + kGroupMode, + ThreadblockShape::kN); + + gemm_k_iterations_per_channel = implicit_gemm_k_iterations_per_channel( + kConvolutionalOperator, args.problem_size, kIteratorAlgorithm); + + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = threadblock_swizzle.get_tiled_shape( + kConvolutionalOperator, + problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.problem_size.split_k_slices); + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + + // Dynamic SMEM usage because stride and dilation are runtime params. + smem_size_ = (cutlass::platform::max(iterator_A.activation_size, int(sizeof(typename Epilogue::SharedStorage))) * kStages + iterator_B.filter_size); + } + + CUTLASS_HOST_DEVICE + int get_smem_size() { + // Dynamic Smem Size + return smem_size_; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +struct ReorderKernel { + using Params = Params_; + using ElementB = ElementB_; + + union SharedStorage {}; + + static unsigned int const kReorderKernelThreadPerCTA = 128; + + CUTLASS_HOST_DEVICE + ReorderKernel() {} + + CUTLASS_HOST_DEVICE + static dim3 get_grid_shape(Params const ¶ms) { + return dim3{static_cast( + (params.problem_size.filter_size() + kReorderKernelThreadPerCTA - 1) / + kReorderKernelThreadPerCTA), + 1, + 1}; + } + + CUTLASS_HOST_DEVICE + static dim3 get_block_shape() { return dim3{kReorderKernelThreadPerCTA, 1, 1}; } + + CUTLASS_HOST_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + int64_t m = static_cast(params.problem_size.groups); + int64_t n = static_cast(params.problem_size.filter_size() / params.problem_size.K); + const ElementB *src_with_type = static_cast(params.ptr_B); + ElementB *dst_with_type = static_cast(params.ptr_reordered_B); + + int64_t linear_index = blockIdx.x * kReorderKernelThreadPerCTA + threadIdx.x; + int64_t index_m = linear_index / n; + int64_t index_n = linear_index % n; + int64_t new_linear_index = index_m + index_n * m; + + if (linear_index < m * n) { + dst_with_type[new_linear_index] = src_with_type[linear_index]; + } + return; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) + typename ConvProblemSize_ = Conv2dProblemSize, ///! Convolutional operator on 2D or 3D problem + conv::GroupMode GroupMode_ = conv::GroupMode::kNone, ///! 
Group mode + typename ThreadBlockOutputShape_ = cutlass::conv::TensorNHWCShape<1, 1, 1, 1> +> +struct DirectConvolution { + + using Mma = Mma_; + using Epilogue = Epilogue_; + using EpilogueOutputOp = typename Epilogue::OutputOp; + using ThreadblockSwizzle = ThreadblockSwizzle_; + using ThreadBlockOutputShape = ThreadBlockOutputShape_; + static Operator const kConvolutionalOperator = ConvOperator; + + using ElementA = typename Mma::IteratorA::Element; + using LayoutA = typename Mma::IteratorA::Layout; + using ElementB = typename Mma::IteratorB::Element; + using LayoutB = typename Mma::IteratorB::Layout; + using ElementC = typename EpilogueOutputOp::ElementOutput; + + /// Set output tensor C layout + using LayoutC = LayoutA; + + using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator; + using ElementCompute = typename EpilogueOutputOp::ElementCompute; + + using WarpMmaOperator = typename Mma::Policy::Operator; + + using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator; + using MathOperator = typename ArchMmaOperator::Operator; + + using OperatorClass = typename WarpMmaOperator::OperatorClass; + using ArchTag = typename WarpMmaOperator::ArchTag; + + using ThreadblockShape = typename Mma::Shape; + using WarpShape = typename WarpMmaOperator::Shape; + using InstructionShape = typename cutlass::gemm::GemmShape<1, 1, 1>; + + static int const kStages = Mma::kStages; + static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm; + static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport; + + /// Warp count (concept: GemmShape) + using WarpCount = typename Mma::WarpCount; + static int const kThreadCount = 32 * WarpCount::kCount; + + using TensorRefA = typename Mma::IteratorA::TensorRef; + using TensorRefB = typename Mma::IteratorB::TensorRef; + using TensorRefC = cutlass::TensorRef; + + /// Check iterator A and B convolution dimension are the same and + // set device::ImplicitGemmConvolution::kConvDim + static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim, + "Convolution on different different dimensions is not supported"); + static int const kConvDim = Mma::IteratorA::kConvDim; + + /// Conv dimension and problem size structure (Conv2d or Conv3d) + using ConvProblemSize = ConvProblemSize_; + + static conv::GroupMode const kGroupMode = GroupMode_; + + + // + // + // + using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< + LayoutC, + typename Epilogue::OutputTileIterator::Layout, + TensorRefC, + ConvOperator, + ConvProblemSize + >; + + + /// Argument structure + struct Arguments { + + // + // Data members + // + + ConvProblemSize problem_size; + TensorRefA ref_A; + TensorRefB ref_B; + TensorRefB ref_reordered_B; + TensorRefC ref_C; + TensorRefC ref_D; + typename EpilogueOutputOp::Params output_op; + SplitKMode split_k_mode; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size + ): + problem_size(problem_size) { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size, + TensorRefA const & ref_A, + TensorRefB const & ref_B, + TensorRefC const & ref_C, + TensorRefC const & ref_D, + typename EpilogueOutputOp::Params const & output_op, + TensorRefB const & ref_reordered_B = nullptr, + SplitKMode const & split_k_mode = SplitKMode::kSerial + ): + problem_size(problem_size), + ref_A(ref_A), + ref_B(ref_B), + ref_C(ref_C), + ref_D(ref_D), + output_op(output_op), 
+ ref_reordered_B(ref_reordered_B), + split_k_mode(split_k_mode) + { + + } + + }; + + using Params = + typename cutlass::conv::kernel::DirectConvolutionParams; + + using ReorderKernel = typename cutlass::conv::kernel::ReorderKernel; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + DirectConvolution() { } + + /// Executes one ImplicitGEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if threadblock is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || + params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { + + return; + } + + // Compute position within threadblock + int thread_idx = threadIdx.x; + int iterator_column_offset = 0; + int filter_row_offset = 0; + if (kGroupMode != GroupMode::kNone) { + if (kGroupMode == GroupMode::kDepthwise) { + iterator_column_offset += threadblock_tile_idx.n() * Mma::Shape::kN; + } + } + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.iterator_A, + params.problem_size, + params.ptr_A, + thread_idx, + MatrixCoord( + threadblock_tile_idx.m() + threadblock_tile_idx.k(), + iterator_column_offset + ) + ); + + typename Mma::IteratorB iterator_B( + params.iterator_B, + params.problem_size, + params.ptr_reordered_B, + thread_idx, + MatrixCoord( + filter_row_offset, + iterator_column_offset + ) + ); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. 
+ int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // Compute logical position within grid + threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + + MatrixCoord threadblock_offset( + threadblock_tile_idx.m() + threadblock_tile_idx.k(), + threadblock_tile_idx.n() * Mma::Shape::kN + ); + + // Tile iterator writing to destination tensor + typename Epilogue::OutputTileIterator iterator_D( + params.iterator_D, + params.ptr_D, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset + ); + + // Tile iterator reading from source accumulator tensor + typename Epilogue::OutputTileIterator iterator_C( + params.iterator_C, + params.ptr_C, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset + ); + + + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + + // Compute threadblock-scoped matrix multiply-add + // Epilogue is fused in the mainloop + mma(params.gemm_k_iterations, + accumulators, + iterator_A, + params.iterator_A, + iterator_B, + params.iterator_B, + accumulators, + epilogue, + output_op, + iterator_D, + iterator_C, + params.split_k_slices); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/implicit_gemm_convolution.h b/include/cutlass/conv/kernel/implicit_gemm_convolution.h index d3f1a19f27..b1e0b477a8 100644 --- a/include/cutlass/conv/kernel/implicit_gemm_convolution.h +++ b/include/cutlass/conv/kernel/implicit_gemm_convolution.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -61,8 +61,9 @@ template < typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate typename Epilogue_, ///! Epilogue typename ThreadblockSwizzle_, ///! Threadblock swizzling function - conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) - typename ConvProblemSize_ = Conv2dProblemSize ///! Convolutional operator on 2D or 3D problem + conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad, Deconv) + typename ConvProblemSize_ = Conv2dProblemSize, ///! Convolutional operator on 2D or 3D problem + conv::GroupMode GroupMode_ = conv::GroupMode::kNone ///! 
Group mode > struct ImplicitGemmConvolution { @@ -117,6 +118,8 @@ struct ImplicitGemmConvolution { /// Conv dimension and problem size structure (Conv2d or Conv3d) using ConvProblemSize = ConvProblemSize_; + static conv::GroupMode const kGroupMode = GroupMode_; + /// Wgrad C stride idx for implicit gemm algorithm // Conv2d row-major matrix C (KxRSC) // Conv3d row-major matrix C (KxTRSC) @@ -198,6 +201,7 @@ struct ImplicitGemmConvolution { int swizzle_log_tile; int gemm_k_iterations; + int gemm_k_iterations_per_channel; typename Mma::IteratorA::Params iterator_A; typename Mma::IteratorA::Element const *ptr_A; typename Mma::IteratorB::Params iterator_B; @@ -229,9 +233,9 @@ struct ImplicitGemmConvolution { ptr_A(args.ref_A.data()), iterator_B(args.problem_size, args.ref_B.layout()), ptr_B(args.ref_B.data()), - iterator_C(ConvOutputIteratorParameter::layout(args.ref_C)), + iterator_C(ConvOutputIteratorParameter::layout(args.ref_C), implicit_gemm_tensor_c_extent(kConvolutionalOperator, args.problem_size)), ptr_C(args.ref_C.data()), - iterator_D(ConvOutputIteratorParameter::layout(args.ref_D)), + iterator_D(ConvOutputIteratorParameter::layout(args.ref_D), implicit_gemm_tensor_c_extent(kConvolutionalOperator, args.problem_size)), ptr_D(args.ref_D.data()), output_op(args.output_op), semaphore(semaphore), @@ -241,7 +245,12 @@ struct ImplicitGemmConvolution { kConvolutionalOperator, ThreadblockShape::kK, args.problem_size, - kIteratorAlgorithm); + kIteratorAlgorithm, + kGroupMode, + ThreadblockShape::kN); + + gemm_k_iterations_per_channel = implicit_gemm_k_iterations_per_channel( + kConvolutionalOperator, args.problem_size, kIteratorAlgorithm); ThreadblockSwizzle threadblock_swizzle; @@ -286,6 +295,17 @@ struct ImplicitGemmConvolution { // Compute position within threadblock int thread_idx = threadIdx.x; + int iterator_A_column_offset = threadblock_tile_idx.k() * Mma::Shape::kK; + if (kGroupMode != GroupMode::kNone) { + if (kGroupMode != GroupMode::kDepthwise) { + int k_per_group = params.problem_size.K / params.problem_size.groups; + int group_idx = threadblock_tile_idx.n() * Mma::Shape::kN / k_per_group; + int channels_per_group = params.problem_size.C / params.problem_size.groups; + iterator_A_column_offset += group_idx * channels_per_group; + } else { + iterator_A_column_offset += threadblock_tile_idx.n() * Mma::Shape::kN; + } + } // Construct iterators to A and B operands typename Mma::IteratorA iterator_A( @@ -295,7 +315,7 @@ struct ImplicitGemmConvolution { thread_idx, MatrixCoord( threadblock_tile_idx.m() * Mma::Shape::kM, - threadblock_tile_idx.k() * Mma::Shape::kK + iterator_A_column_offset ) ); @@ -312,7 +332,7 @@ struct ImplicitGemmConvolution { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. 
- int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int warp_idx = canonical_warp_idx_sync(); int lane_idx = threadIdx.x % 32; // @@ -327,7 +347,7 @@ struct ImplicitGemmConvolution { accumulators.clear(); // Compute threadblock-scoped matrix multiply-add - mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators, params.gemm_k_iterations_per_channel); // // Epilogue @@ -377,7 +397,6 @@ struct ImplicitGemmConvolution { threadblock_offset ); - // Construct the epilogue Epilogue epilogue( shared_storage.epilogue, diff --git a/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h b/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h index d43521f155..74ecae4014 100644 --- a/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h +++ b/include/cutlass/conv/kernel/implicit_gemm_convolution_fusion.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -119,6 +119,8 @@ struct ImplicitGemmConvolutionFusion { /// Conv dimension and problem size structure (Conv2d or Conv3d) using ConvProblemSize = ConvProblemSize_; + static conv::GroupMode const kGroupMode = conv::GroupMode::kNone; + /// Wgrad C stride idx for implicit gemm algorithm // Conv2d row-major matrix C (KxRSC) // Conv3d row-major matrix C (KxTRSC) @@ -200,32 +202,30 @@ struct ImplicitGemmConvolutionFusion { /// Parameters structure struct Params { - ConvProblemSize problem_size; - cutlass::gemm::GemmCoord grid_tiled_shape; - gemm::GemmCoord implicit_gemm_problem_size; - int swizzle_log_tile; - int gemm_k_iterations; - typename Mma::IteratorA::Params iterator_A; - typename Mma::IteratorA::Element const *ptr_A; - typename Mma::IteratorB::Params iterator_B; - typename Mma::IteratorB::Element const *ptr_B; - typename Mma::IteratorScaleBias::Params iterator_scale_bias; - typename Mma::IteratorScaleBias::Element const *ptr_scale; - typename Mma::IteratorScaleBias::Element const *ptr_bias; - typename Epilogue::OutputTileIterator::Params iterator_C; - typename Epilogue::OutputTileIterator::Element *ptr_C; - typename Epilogue::OutputTileIterator::Params iterator_D; - typename Epilogue::OutputTileIterator::Element *ptr_D; - typename EpilogueOutputOp::Params output_op; - int *semaphore; - SplitKMode split_k_mode; + ConvProblemSize problem_size{}; + cutlass::gemm::GemmCoord grid_tiled_shape{}; + gemm::GemmCoord implicit_gemm_problem_size{}; + int swizzle_log_tile{0}; + int gemm_k_iterations{0}; + typename Mma::IteratorA::Params iterator_A{}; + typename Mma::IteratorA::Element const *ptr_A = nullptr; + typename Mma::IteratorB::Params iterator_B{}; + typename Mma::IteratorB::Element const *ptr_B = nullptr; + typename Mma::IteratorScaleBias::Params iterator_scale_bias{}; + typename Mma::IteratorScaleBias::Element const *ptr_scale = nullptr; + typename Mma::IteratorScaleBias::Element const *ptr_bias = nullptr; + typename Epilogue::OutputTileIterator::Params iterator_C {}; + typename Epilogue::OutputTileIterator::Element *ptr_C = nullptr; + typename Epilogue::OutputTileIterator::Params iterator_D {}; + typename Epilogue::OutputTileIterator::Element *ptr_D = nullptr; + 
typename EpilogueOutputOp::Params output_op {}; + int *semaphore = nullptr; + SplitKMode split_k_mode {}; // // Methods // - - CUTLASS_HOST_DEVICE - Params(): swizzle_log_tile(0), gemm_k_iterations(0) { } + Params() = default; /// CUTLASS_HOST_DEVICE @@ -337,7 +337,7 @@ struct ImplicitGemmConvolutionFusion { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int warp_idx = canonical_warp_idx_sync(); int lane_idx = threadIdx.x % 32; // diff --git a/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h b/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h index 65191f5a6a..bf00f90bac 100644 --- a/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h +++ b/include/cutlass/conv/kernel/implicit_gemm_convolution_strided_dgrad.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -117,6 +117,8 @@ struct ImplicitGemmConvolutionStridedDgrad { /// Conv dimension and problem size structure (Conv2d or Conv3d) using ConvProblemSize = ConvProblemSize_; + static conv::GroupMode const kGroupMode = conv::GroupMode::kNone; + /// Wgrad C stride idx for implicit gemm algorithm // Conv2d row-major matrix C (KxRSC) // Conv3d row-major matrix C (KxTRSC) @@ -156,21 +158,20 @@ struct ImplicitGemmConvolutionStridedDgrad { // Data members // - ConvProblemSize problem_size; - TensorRefA ref_A; - TensorRefB ref_B; - TensorRefC ref_C; - TensorRefC ref_D; - typename EpilogueOutputOp::Params output_op; - SplitKMode split_k_mode; + ConvProblemSize problem_size{}; + TensorRefA ref_A{}; + TensorRefB ref_B{}; + TensorRefC ref_C{}; + TensorRefC ref_D{}; + typename EpilogueOutputOp::Params output_op{}; + SplitKMode split_k_mode{}; // // Methods // /// Default ctor - CUTLASS_HOST_DEVICE - Arguments() { } + Arguments() = default; CUTLASS_HOST_DEVICE Arguments( @@ -203,29 +204,28 @@ struct ImplicitGemmConvolutionStridedDgrad { /// Parameters structure struct Params { - ConvProblemSize problem_size; - cutlass::gemm::GemmCoord grid_tiled_shape; - FastDivmod stride_h_divmod; - FastDivmod stride_w_divmod; - int gemm_k_iterations; - typename Mma::IteratorA::Params iterator_A; - typename Mma::IteratorA::Element const *ptr_A; - typename Mma::IteratorB::Params iterator_B; - typename Mma::IteratorB::Element const *ptr_B; - typename Epilogue::OutputTileIterator::Params iterator_C; - typename Epilogue::OutputTileIterator::Element *ptr_C; - typename Epilogue::OutputTileIterator::Params iterator_D; - typename Epilogue::OutputTileIterator::Element *ptr_D; - typename EpilogueOutputOp::Params output_op; - int *semaphore; - SplitKMode split_k_mode; + ConvProblemSize problem_size{}; + cutlass::gemm::GemmCoord grid_tiled_shape{}; + int swizzle_log_tile{0}; + FastDivmod stride_h_divmod{}; + FastDivmod stride_w_divmod{}; + int gemm_k_iterations{0}; + typename Mma::IteratorA::Params iterator_A{}; + typename Mma::IteratorA::Element const *ptr_A = nullptr; + typename Mma::IteratorB::Params iterator_B{}; + typename Mma::IteratorB::Element const *ptr_B = nullptr; + typename Epilogue::OutputTileIterator::Params iterator_C{}; + typename 
Epilogue::OutputTileIterator::Element *ptr_C = nullptr; + typename Epilogue::OutputTileIterator::Params iterator_D{}; + typename Epilogue::OutputTileIterator::Element *ptr_D = nullptr; + typename EpilogueOutputOp::Params output_op {}; + int *semaphore = nullptr; + SplitKMode split_k_mode {}; // // Methods // - - CUTLASS_HOST_DEVICE - Params(): gemm_k_iterations(0) { } + Params() = default; /// CUTLASS_HOST_DEVICE @@ -257,6 +257,8 @@ struct ImplicitGemmConvolutionStridedDgrad { args.problem_size, {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, args.problem_size.split_k_slices); + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); } }; @@ -281,7 +283,7 @@ struct ImplicitGemmConvolutionStridedDgrad { ThreadblockSwizzle threadblock_swizzle; cutlass::gemm::GemmCoord threadblock_tile_idx = - threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); // Early exit if CTA is out of range if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || @@ -333,7 +335,7 @@ struct ImplicitGemmConvolutionStridedDgrad { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int warp_idx = canonical_warp_idx_sync(); int lane_idx = threadIdx.x % 32; // Check if CTA contributes valid MMA (Dy * w) and accumulator will be non-zero after MMA @@ -387,16 +389,15 @@ struct ImplicitGemmConvolutionStridedDgrad { // Construct the semaphore. int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m(); - Semaphore semaphore(params.semaphore + block_idx, thread_idx); - + // Compute logical position within grid threadblock_tile_idx = - threadblock_swizzle.get_tile_offset(params.grid_tiled_shape); + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); // If performing a reduction via split-K, fetch the initial synchronization if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { - + // Fetch the synchronization lock initially but do not block. semaphore.fetch(); @@ -419,51 +420,51 @@ struct ImplicitGemmConvolutionStridedDgrad { start_r, start_s, threadblock_offset ); - - // Tile iterator reading from source accumulator tensor - typename Epilogue::OutputTileIterator iterator_C( - params.iterator_C, - params.ptr_C, - ConvOutputIteratorParameter::extent(params.problem_size), - thread_idx, - params.stride_h_divmod, params.stride_w_divmod, - start_r, start_s, - threadblock_offset - ); - // Construct the epilogue Epilogue epilogue( - shared_storage.epilogue, - thread_idx, - warp_idx, + shared_storage.epilogue, + thread_idx, + warp_idx, lane_idx); - // Wait on the semaphore - this latency may have been covered by iterator construction - if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { - - // For subsequent threadblocks, the source matrix is held in the 'D' tensor. 
- if (threadblock_tile_idx.k()) { - iterator_C = iterator_D; - } + if (output_op.is_source_needed()) + { + // Tile iterator reading from source accumulator tensor + typename Epilogue::OutputTileIterator iterator_C( + params.iterator_C, + params.ptr_C, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + params.stride_h_divmod, params.stride_w_divmod, + start_r, start_s, + threadblock_offset); - semaphore.wait(threadblock_tile_idx.k()); + // Wait on the semaphore - this latency may have been covered by iterator construction + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_idx.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_idx.k()); + } + + // Run epilogue with addend source iterator + epilogue(output_op, iterator_D, accumulators, iterator_C); } - // Each split-k-slice writes to a unique tensor location - else if (params.split_k_mode == SplitKMode::kParallel) { - iterator_D.add_pointer_offset(threadblock_tile_idx.k() * - cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size)); + else + { + // Run epilogue without addend source iterator + epilogue(output_op, iterator_D, accumulators); } - // Run efficient epilogue - epilogue(output_op, iterator_D, accumulators, iterator_C); - // // Release the semaphore // - if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { int lock = 0; if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) { @@ -475,10 +476,11 @@ struct ImplicitGemmConvolutionStridedDgrad { // Otherwise, the semaphore is incremented lock = threadblock_tile_idx.k() + 1; } - + semaphore.release(lock); } - } + + } }; ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -488,4 +490,3 @@ struct ImplicitGemmConvolutionStridedDgrad { } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/include/cutlass/conv/kernel/implicit_gemm_convolution_with_absmax.h b/include/cutlass/conv/kernel/implicit_gemm_convolution_with_absmax.h new file mode 100644 index 0000000000..b05fd2d3ed --- /dev/null +++ b/include/cutlass/conv/kernel/implicit_gemm_convolution_with_absmax.h @@ -0,0 +1,494 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Convolution kernel with an epilogue that computes the absolute maximum value of the output + and a pre-activation-function auxiliary output. The auxiliary output is also (optionally) + stored to global memory. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cutlass/aligned_buffer.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/semaphore.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/conv3d_problem_size.h" +#include "cutlass/epilogue/threadblock/output_iterator_parameter.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace kernel { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + typename Mma_, ///! Threadblock-scoped matrix multiply-accumulate + typename Epilogue_, ///! Epilogue + typename ThreadblockSwizzle_, ///! Threadblock swizzling function + conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) + typename ConvProblemSize_ = Conv2dProblemSize ///! 
Convolutional operator on 2D or 3D problem
+>
+struct ImplicitGemmConvolutionWithAbsMax {
+
+  using Mma = Mma_;
+  using Epilogue = Epilogue_;
+  using EpilogueOutputOp = typename Epilogue::OutputOp;
+  using ThreadblockSwizzle = ThreadblockSwizzle_;
+  static Operator const kConvolutionalOperator = ConvOperator;
+
+  using ElementA = typename Mma::IteratorA::Element;
+  using LayoutA = typename Mma::IteratorA::Layout;
+  using ElementB = typename Mma::IteratorB::Element;
+  using LayoutB = typename Mma::IteratorB::Layout;
+  using ElementC = typename EpilogueOutputOp::ElementOutput;
+
+  /// Set output tensor C layout
+  using LayoutC = LayoutA;
+
+  using ElementAccumulator = typename EpilogueOutputOp::ElementAccumulator;
+  using ElementCompute = typename EpilogueOutputOp::ElementCompute;
+
+  using WarpMmaOperator = typename Mma::Policy::Operator;
+
+  using ArchMmaOperator = typename WarpMmaOperator::ArchMmaOperator;
+  using MathOperator = typename ArchMmaOperator::Operator;
+
+  using OperatorClass = typename WarpMmaOperator::OperatorClass;
+  using ArchTag = typename WarpMmaOperator::ArchTag;
+
+  using ThreadblockShape = typename Mma::Shape;
+  using WarpShape = typename WarpMmaOperator::Shape;
+  using InstructionShape = typename ArchMmaOperator::Shape;
+
+  static int const kStages = Mma::kStages;
+  static IteratorAlgorithm const kIteratorAlgorithm = Mma::IteratorA::kIteratorAlgorithm;
+  static StrideSupport const kStrideSupport = Mma::IteratorA::kStrideSupport;
+
+  /// Warp count (concept: GemmShape)
+  using WarpCount = typename Mma::WarpCount;
+  static int const kThreadCount = 32 * WarpCount::kCount;
+
+  using TensorRefA = typename Mma::IteratorA::TensorRef;
+  using TensorRefB = typename Mma::IteratorB::TensorRef;
+  using TensorRefC = cutlass::TensorRef<ElementC, LayoutC>;
+  using TensorRefAux = cutlass::TensorRef;
+
+  /// Check iterator A and B convolution dimension are the same and
+  // set device::ImplicitGemmConvolution::kConvDim
+  static_assert(Mma::IteratorA::kConvDim == Mma::IteratorB::kConvDim,
+    "Convolution on different dimensions is not supported");
+  static int const kConvDim = Mma::IteratorA::kConvDim;
+
+  /// Conv dimension and problem size structure (Conv2d or Conv3d)
+  using ConvProblemSize = ConvProblemSize_;
+
+  static conv::GroupMode const kGroupMode = conv::GroupMode::kNone;
+
+  /// Wgrad C stride idx for implicit gemm algorithm
+  // Conv2d row-major matrix C (KxRSC)
+  // Conv3d row-major matrix C (KxTRSC)
+  static int const kWgradCStrideIdx =
+    platform::is_same<LayoutC, cutlass::layout::TensorNHWC>::value ? 2 : 3;
+
+  /// This chooses the appropriate stride element of the C tensor.
+  static int const kTensorCStrideIdx =
+    (kConvolutionalOperator == conv::Operator::kWgrad ?
kWgradCStrideIdx : 0); + + // + // + // + using ConvOutputIteratorParameter = epilogue::threadblock::ConvOutputIteratorParameter< + LayoutC, + typename Epilogue::OutputTileIterator::Layout, + TensorRefC, + ConvOperator, + ConvProblemSize + >; + + /// Argument structure + struct Arguments { + + // + // Data members + // + + ConvProblemSize problem_size; + TensorRefA ref_A; + TensorRefB ref_B; + TensorRefC ref_C; + TensorRefC ref_D; + TensorRefC ref_Aux; + + typename EpilogueOutputOp::Params output_op; + SplitKMode split_k_mode; + + void * ptr_Vector; + + typename LayoutC::Stride::Index ldr; + + // + // Methods + // + + /// Default ctor + CUTLASS_HOST_DEVICE + Arguments() { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size + ): + problem_size(problem_size) { } + + CUTLASS_HOST_DEVICE + Arguments( + ConvProblemSize const & problem_size, + TensorRefA const & ref_A, + TensorRefB const & ref_B, + TensorRefC const & ref_C, + TensorRefC const & ref_D, + TensorRefAux const & ref_Aux, + typename EpilogueOutputOp::Params const & output_op, + SplitKMode const & split_k_mode = SplitKMode::kSerial, + void * ptr_Vector = nullptr, + typename LayoutC::Stride::Index ldr = 0 + ): + problem_size(problem_size), + ref_A(ref_A), + ref_B(ref_B), + ref_C(ref_C), + ref_D(ref_D), + ref_Aux(ref_Aux), + output_op(output_op), + split_k_mode(split_k_mode), + ptr_Vector(ptr_Vector), + ldr(ldr) + { + + } + + }; + + /// Parameters structure + struct Params { + ConvProblemSize problem_size; + cutlass::gemm::GemmCoord grid_tiled_shape; + gemm::GemmCoord implicit_gemm_problem_size; + int swizzle_log_tile; + + int gemm_k_iterations; + typename Mma::IteratorA::Params iterator_A; + typename Mma::IteratorA::Element const *ptr_A; + typename Mma::IteratorB::Params iterator_B; + typename Mma::IteratorB::Element const *ptr_B; + typename Epilogue::OutputTileIterator::Params iterator_C; + typename Epilogue::OutputTileIterator::Element *ptr_C; + typename Epilogue::OutputTileIterator::Params iterator_D; + typename Epilogue::OutputTileIterator::Element *ptr_D; + typename Epilogue::AuxOutputTileIterator::Params iterator_Aux; + typename Epilogue::AuxOutputTileIterator::Element *ptr_Aux; + typename EpilogueOutputOp::Params output_op; + int *semaphore; + SplitKMode split_k_mode; + + void * ptr_Vector; + typename LayoutC::Stride::Index ldr; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Params(): + swizzle_log_tile(0), + gemm_k_iterations(0), + ptr_Vector(nullptr), + ldr(0) + { } + + /// + CUTLASS_HOST_DEVICE + Params( + Arguments const &args, + int *semaphore = nullptr + ): + problem_size(args.problem_size), + implicit_gemm_problem_size(cutlass::conv::implicit_gemm_problem_size(kConvolutionalOperator, args.problem_size)), + iterator_A(Mma::IteratorA::getParams(args.problem_size, args.ref_A.layout())), + ptr_A(args.ref_A.data()), + iterator_B(args.problem_size, args.ref_B.layout()), + ptr_B(args.ref_B.data()), + iterator_C(ConvOutputIteratorParameter::layout(args.ref_C)), + ptr_C(args.ref_C.data()), + iterator_D(ConvOutputIteratorParameter::layout(args.ref_D)), + ptr_D(args.ref_D.data()), + iterator_Aux(ConvOutputIteratorParameter::layout(args.ref_Aux)), + ptr_Aux(args.ref_Aux.data()), + output_op(args.output_op), + semaphore(semaphore), + split_k_mode(args.split_k_mode), + ptr_Vector(args.ptr_Vector), + ldr(args.ldr) + + { + gemm_k_iterations = implicit_gemm_k_iterations(kConvolutionalOperator, ThreadblockShape::kK, args.problem_size); + + ThreadblockSwizzle threadblock_swizzle; + + grid_tiled_shape = 
threadblock_swizzle.get_tiled_shape( + implicit_gemm_problem_size, + {ThreadblockShape::kM, ThreadblockShape::kN, ThreadblockShape::kK}, + args.problem_size.split_k_slices); + + swizzle_log_tile = threadblock_swizzle.get_log_tile(grid_tiled_shape); + } + }; + + /// Shared memory storage structure + union SharedStorage { + typename Mma::SharedStorage main_loop; + typename Epilogue::SharedStorage epilogue; + }; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + ImplicitGemmConvolutionWithAbsMax() { } + + /// Executes one ImplicitGEMM + CUTLASS_DEVICE + void operator()(Params const ¶ms, SharedStorage &shared_storage) { + + // Compute threadblock location + ThreadblockSwizzle threadblock_swizzle; + + cutlass::gemm::GemmCoord threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // Early exit if CTA is out of range + if (params.grid_tiled_shape.m() <= threadblock_tile_idx.m() || + params.grid_tiled_shape.n() <= threadblock_tile_idx.n()) { + + return; + } + + // Compute position within threadblock + int thread_idx = threadIdx.x; + + // Construct iterators to A and B operands + typename Mma::IteratorA iterator_A( + params.iterator_A, + params.problem_size, + params.ptr_A, + thread_idx, + MatrixCoord( + threadblock_tile_idx.m() * Mma::Shape::kM, + threadblock_tile_idx.k() * Mma::Shape::kK + ) + ); + + typename Mma::IteratorB iterator_B( + params.iterator_B, + params.problem_size, + params.ptr_B, + thread_idx, + MatrixCoord( + threadblock_tile_idx.k() * Mma::Shape::kK, + threadblock_tile_idx.n() * Mma::Shape::kN + ) + ); + + // Broadcast the warp_id computed by lane 0 to ensure dependent code + // is compiled as warp-uniform. + int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int lane_idx = threadIdx.x % 32; + + // + // Main loop + // + + // Construct thread-scoped matrix multiply + Mma mma(shared_storage.main_loop, thread_idx, warp_idx, lane_idx); + + typename Mma::FragmentC accumulators; + + accumulators.clear(); + + // Compute threadblock-scoped matrix multiply-add + mma(params.gemm_k_iterations, accumulators, iterator_A, iterator_B, accumulators); + + // + // Epilogue + // + + EpilogueOutputOp output_op(params.output_op); + + // Construct the semaphore. + int block_idx = threadblock_tile_idx.m() + threadblock_tile_idx.n() * params.grid_tiled_shape.m(); + + Semaphore semaphore(params.semaphore + block_idx, thread_idx); + + // Compute logical position within grid + threadblock_tile_idx = + threadblock_swizzle.get_tile_offset(params.swizzle_log_tile); + + // If performing a reduction via split-K, fetch the initial synchronization + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + // Fetch the synchronization lock initially but do not block. + semaphore.fetch(); + + // Indicate which position in a serial reduction the output operator is currently updating + output_op.set_k_partition(threadblock_tile_idx.k(), params.grid_tiled_shape.k()); + } + + MatrixCoord threadblock_offset( + threadblock_tile_idx.m() * Mma::Shape::kM, + threadblock_tile_idx.n() * Mma::Shape::kN + ); + + // Tile iterator writing to destination tensor + typename Epilogue::OutputTileIterator iterator_D( + params.iterator_D, + params.ptr_D, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset + ); + + // Tile iterator writing to auxiliary tensor. 
+ typename Epilogue::AuxOutputTileIterator iterator_Aux( + params.iterator_Aux, + params.ptr_Aux, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset + ); + + // Tile iterator reading from source accumulator tensor + typename Epilogue::OutputTileIterator iterator_C( + params.iterator_C, + params.ptr_C, + ConvOutputIteratorParameter::extent(params.problem_size), + thread_idx, + threadblock_offset + ); + + // Define the reduction output pointer and move to the appropriate place + typename Epilogue::ElementVector *ptr_Vector = + static_cast(params.ptr_Vector); + + + // Construct the epilogue + Epilogue epilogue( + shared_storage.epilogue, + thread_idx, + warp_idx, + lane_idx); + + // Move to appropriate location for this output tile + if (ptr_Vector) { + ptr_Vector += threadblock_offset.column() + threadblock_tile_idx.m() * params.ldr; + } + + // Wait on the semaphore - this latency may have been covered by iterator construction + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + // For subsequent threadblocks, the source matrix is held in the 'D' tensor. + if (threadblock_tile_idx.k()) { + iterator_C = iterator_D; + } + + semaphore.wait(threadblock_tile_idx.k()); + + } + // Each split-k-slice writes to a unique tensor location + else if (params.split_k_mode == SplitKMode::kParallel) { + iterator_D.add_pointer_offset(threadblock_tile_idx.k() * + cutlass::conv::implicit_gemm_tensor_c_size(ConvOperator, params.problem_size)); + } + + // Execute the epilogue operator to update the destination tensor. + epilogue(output_op, + // Only the final block uses Vector + ((params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) && + (params.grid_tiled_shape.k() != threadblock_tile_idx.k() + 1)) + ? nullptr + : ptr_Vector, + iterator_D, + accumulators, + iterator_C, + iterator_Aux, + ConvOutputIteratorParameter::extent(params.problem_size), + threadblock_offset); + + // + // Release the semaphore + // + + if (params.split_k_mode == SplitKMode::kSerial && params.grid_tiled_shape.k() > 1) { + + int lock = 0; + if (params.grid_tiled_shape.k() == threadblock_tile_idx.k() + 1) { + + // The final threadblock resets the semaphore for subsequent grids. + lock = 0; + } + else { + // Otherwise, the semaphore is incremented + lock = threadblock_tile_idx.k() + 1; + } + + semaphore.release(lock); + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace kernel +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h b/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h index 2ab47637a0..1f27e0686d 100644 --- a/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h +++ b/include/cutlass/conv/kernel/implicit_gemm_convolution_with_fused_epilogue.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -61,7 +61,7 @@ template < typename Mma_, ///! 
Threadblock-scoped matrix multiply-accumulate typename Epilogue_, ///! Epilogue typename ThreadblockSwizzle_, ///! Threadblock swizzling function - conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad) + conv::Operator ConvOperator, ///! Convolutional operator (Fprop, Dgrad, Wgrad, Deconv) typename ConvProblemSize_ = Conv2dProblemSize ///! Convolutional operator on 2D or 3D problem > struct ImplicitGemmConvolutionWithFusedEpilogue { @@ -117,6 +117,8 @@ struct ImplicitGemmConvolutionWithFusedEpilogue { /// Conv dimension and problem size structure (Conv2d or Conv3d) using ConvProblemSize = ConvProblemSize_; + static conv::GroupMode const kGroupMode = conv::GroupMode::kNone; + /// Wgrad C stride idx for implicit gemm algorithm // Conv2d row-major matrix C (KxRSC) // Conv3d row-major matrix C (KxTRSC) @@ -255,9 +257,9 @@ struct ImplicitGemmConvolutionWithFusedEpilogue { ptr_A(args.ref_A.data()), iterator_B(args.problem_size, args.ref_B.layout()), ptr_B(args.ref_B.data()), - iterator_C(ConvOutputIteratorParameter::layout(args.ref_C)), + iterator_C(ConvOutputIteratorParameter::layout(args.ref_C), implicit_gemm_tensor_c_extent(kConvolutionalOperator, args.problem_size)), ptr_C(args.ref_C.data()), - iterator_D(ConvOutputIteratorParameter::layout(args.ref_D)), + iterator_D(ConvOutputIteratorParameter::layout(args.ref_D), implicit_gemm_tensor_c_extent(kConvolutionalOperator, args.problem_size)), ptr_D(args.ref_D.data()), output_op(args.output_op), semaphore(semaphore), @@ -339,7 +341,7 @@ struct ImplicitGemmConvolutionWithFusedEpilogue { // Broadcast the warp_id computed by lane 0 to ensure dependent code // is compiled as warp-uniform. - int warp_idx = __shfl_sync(0xffffffff, threadIdx.x / 32, 0); + int warp_idx = canonical_warp_idx_sync(); int lane_idx = threadIdx.x % 32; // diff --git a/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp b/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp new file mode 100644 index 0000000000..657ac6b3ec --- /dev/null +++ b/include/cutlass/conv/kernel/sm90_implicit_gemm_tma_warpspecialized.hpp @@ -0,0 +1,76 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/kernel_hardware_info.hpp" + +#include "cute/tensor.hpp" +#include "cute/arch/cluster_sm90.hpp" + +#include "cutlass/conv/detail.hpp" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/dispatch_policy.hpp" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/pipeline/sm90_pipeline.hpp" +#include "cutlass/gemm/kernel/gemm_universal.hpp" +#include "cutlass/gemm/kernel/tile_scheduler.hpp" + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::conv::kernel { + +/////////////////////////////////////////////////////////////////////////////// + +template < + class ProblemShape_, + class CollectiveMainloop_, + class CollectiveEpilogue_, + class TileScheduler_ +> +class ConvUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_, + cute::enable_if_t> +> : public cutlass::gemm::kernel::GemmUniversal< + ProblemShape_, + CollectiveMainloop_, + CollectiveEpilogue_, + TileScheduler_ +> +{}; +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::conv::kernel + diff --git a/include/cutlass/conv/thread/depthwise_mma.h b/include/cutlass/conv/thread/depthwise_mma.h new file mode 100644 index 0000000000..37ece7927e --- /dev/null +++ b/include/cutlass/conv/thread/depthwise_mma.h @@ -0,0 +1,325 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates exposing architecture support for depthwise convolution +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/arch/mma.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/thread/mma.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace thread { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// MMA operation +template < + /// Size of the matrix product (concept: GemmShape) + typename Shape_, + /// Number of threads participating + int kThreads_, + /// Data type of A elements + typename ElementA, + /// Data type of B elements + typename ElementB, + /// Element type of C matrix + typename ElementC, + /// Inner product operator + typename Operator +> +struct ElementwiseInnerProduct; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// General implementation +template < + /// Size of the matrix product (concept: GemmShape) + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Data type of B elements + typename ElementB_, + /// Element type of C matrix + typename ElementC_> +struct ElementwiseInnerProduct { + using Shape = Shape_; + using Operator = arch::OpMultiplyAdd; + using ElementC = ElementC_; + + CUTLASS_HOST_DEVICE + void operator()(Array &d, + Array const &a, + Array const &b, + Array const &c) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < Shape::kN; ++i) { + d[i] = a[i] * b[i] + c[i]; + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Specialization of half_t +template <> +struct ElementwiseInnerProduct< + gemm::GemmShape<2, 2, 1>, + 1, + half_t, + half_t, + half_t, + arch::OpMultiplyAdd> { + + using Shape = gemm::GemmShape<2, 2, 1>; + using Operator = arch::OpMultiplyAdd; + using ElementC = half_t; + + CUTLASS_HOST_DEVICE + void operator()( + Array &d, + Array const &a, + Array const &b, + Array const &c + ) { + +#if (defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 600)) + + __half2 const & A = reinterpret_cast<__half2 const &>(a); + __half2 const & B = reinterpret_cast<__half2 const &>(b); + __half2 const & C = reinterpret_cast<__half2 const &>(c); + + __half2 tmp_D = __hfma2(A, B, C); + + d = reinterpret_cast const &>(tmp_D); + +#else + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < 2; ++i) { + d[i] = a[i] * b[i] + c[i]; + } +#endif + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape, + /// Data type of A elements + typename ElementA, + /// Data 
type of B elements + typename ElementB, + /// Element type of C matrix + typename ElementC, + /// Concept: arch::OpMultiplyAdd or arch::Mma<> + typename Operator = arch::OpMultiplyAdd, + /// Used for partial specialization + typename Enable = bool +> +struct DepthwiseDirectConvElementwiseInnerProduct; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Gemplate that handles all packed matrix layouts +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Data type of B elements + typename ElementB_, + /// Element type of C matrix + typename ElementC_, + /// Operator used to compute GEMM + typename Operator_ +> +struct DepthwiseDirectConvElementwiseInnerProductGeneric { + + /// Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + /// Data type of operand A + using ElementA = ElementA_; + + /// Data type of operand B + using ElementB = ElementB_; + + /// Element type of operand C + using ElementC = ElementC_; + + /// Underlying mathematical operator + using Operator = Operator_; + + /// A operand storage + using FragmentA = Array; + + /// B operand storage + using FragmentB = Array; + + /// C operand storage + using FragmentC = Array; + + /// Instruction + using MmaOp = cutlass::conv::thread::ElementwiseInnerProduct< + gemm::GemmShape, + 1, + ElementA, + ElementB, + ElementC, + Operator>; + + + // + // Methods + // + + /// Computes a matrix product D = A * B + C + CUTLASS_HOST_DEVICE + void operator()( + FragmentC & D, + FragmentA const & A, + FragmentB const & B, + FragmentC const & C) { + Array *ptr_D = reinterpret_cast *>(&D); + Array const *ptr_A = + reinterpret_cast const *>(&A); + Array const *ptr_B = + reinterpret_cast const *>(&B); + + MmaOp mma_op; + + // Copy accumulators + D = C; + + // Compute matrix product + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Shape::kN / MmaOp::Shape::kN; ++n) { + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < Shape::kM; ++m) { + + Array tmpD = ptr_D[m * Shape::kN / MmaOp::Shape::kN + n]; + Array tmpA = ptr_A[m * Shape::kN / MmaOp::Shape::kN + n]; + Array tmpB = ptr_B[n]; + + mma_op(tmpD, tmpA, tmpB, tmpD); + + ptr_D[m * Shape::kN / MmaOp::Shape::kN + n] = tmpD; + + } + } + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Data type of B elements + typename ElementB_, + /// Element type of C matrix + typename ElementC_ +> +struct DepthwiseDirectConvElementwiseInnerProduct< + Shape_, + ElementA_, + ElementB_, + ElementC_, + arch::OpMultiplyAdd + > { + /// Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + /// Data type of operand A + using ElementA = ElementA_; + + /// Data type of operand B + using ElementB = ElementB_; + + /// Element type of operand C + using ElementC = ElementC_; + + /// Underlying mathematical operator + using Operator = arch::OpMultiplyAdd; + + /// A operand storage + using FragmentA = + Array; // output_tile_size per thread * groups_per_thread + + /// B operand storage + using FragmentB = Array; // 1 * groups_per_thread + + /// C operand storage + using FragmentC = + Array; // output_tile_size per thread * groups_per_thread + + static bool const use_optimized = 0; + + using 
ArchMmaOperator = DepthwiseDirectConvElementwiseInnerProductGeneric; + + // + // Methods + // + + /// Computes a matrix product D = A * B + C + CUTLASS_HOST_DEVICE + void operator()( + FragmentC & D, + FragmentA const & A, + FragmentB const & B, + FragmentC const & C) { + + ArchMmaOperator mma; + + mma(D, A, B, C); + + } +}; + +} // namespace thread +} // namespace conv +} // namespace cutlass diff --git a/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h index 0af34babfd..978c14feb6 100644 --- a/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_analytic.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h index 4014173b48..6fb1cb18e9 100644 --- a/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_dgrad_filter_tile_access_iterator_optimized.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -248,7 +248,7 @@ class Conv2dDgradFilterTileAccessIteratorOptimized < pointer_ += pointer_offset * sizeof_bits::value / 8; } - CUTLASS_HOST_DEVICE + CUTLASS_DEVICE void advance() { int next_idx = 0; @@ -263,18 +263,33 @@ class Conv2dDgradFilterTileAccessIteratorOptimized < // Move filter_r by stride_h filter_r_ += problem_size_.stride_h; - +#if 0 bool check = (filter_r_ < problem_size_.R); filter_r_ = check ? filter_r_ : start_r_; next_idx = check ? 1 : 2; reset_bytes += (check ? 
reset_bytes_s_ : reset_bytes_r_); +#else + asm volatile( + "{\n\t" + " .reg .pred %%p;\n\t" + " .reg .s64 t1;\n\t" + " setp.lt.s32 %%p, %3, %4;\n\t" + " selp.s32 %0, %3, %5, %%p;\n\t" + " selp.s32 %1, 1, 2, %%p;\n\t" + " selp.s64 t1, %6, %7, %%p;\n\t" + " add.s64 %2, %8, t1;\n\t" + "}\n" + : "=r"(filter_r_), "=r"(next_idx), "=l"(reset_bytes) + : "r"(filter_r_), "r"(problem_size_.R), "r"(start_r_), + "l"(reset_bytes_s_), "l"(reset_bytes_r_), "l"(reset_bytes)); +#endif } // offset pointers by offset_bytes pointer_ += (params_.inc_next[next_idx] - reset_bytes); - if (next_idx == 2) { + if (next_idx == 2) { filter_k_ += params_.filter_k_delta; } diff --git a/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h index 80448f361e..1de41f3f7b 100644 --- a/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_analytic.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -528,7 +528,6 @@ class Conv2dDgradOutputGradientTileAccessIteratorAnalytic < int k = filter_k_ + iteration_vector_ * AccessType::kElements; return TensorCoord(n, p, q, k); - } /// Returns true if the current coordinate is within the output tensor Dy diff --git a/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h index 4b1e906aea..ffa13c934d 100644 --- a/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_dgrad_output_gradient_tile_access_iterator_optimized.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -230,7 +230,7 @@ class Conv2dDgradOutputGradientTileAccessIteratorOptimized < offset_p[s] = (mapped_h + problem_size_.pad_h - filter_r) / problem_size_.stride_h; offset_q[s] = (mapped_w + problem_size_.pad_w - filter_s) / problem_size_.stride_w; - // Intialize pointers for gemm_k=0 + // Initialize pointers for gemm_k=0 TensorCoord coord{offset_n[s], offset_p[s], offset_q[s], filter_k_}; pointer_[s] += params_.layout(coord) * sizeof_bits::value / 8; @@ -321,7 +321,7 @@ class Conv2dDgradOutputGradientTileAccessIteratorOptimized < add_byte_offset_(pointer_offset * sizeof_bits::value / 8); } - CUTLASS_HOST_DEVICE + CUTLASS_DEVICE void advance() { int next_idx = 0; @@ -336,23 +336,37 @@ class Conv2dDgradOutputGradientTileAccessIteratorOptimized < // Move filter_r by stride_h filter_r_ += problem_size_.stride_h; +#if 0 if (filter_r_ < problem_size_.R) { - + next_idx = 1; - // Restore bytes in q coordinate (Mma in filter s dimenstion) + // Restore bytes in q coordinate (Mma in filter s dimension) reset_bytes = reset_bytes_s_; } else { // Restore filter_r filter_r_ = start_r_; - + next_idx = 2; - - // Restore bytes in p and q coordinate (Mma in filter s and r dimenstion) + + // Restore bytes in p and q coordinate (Mma in filter s and r dimension) reset_bytes = reset_bytes_r_; } +#else + asm volatile( + "{\n\t" + " .reg .pred %%p;\n\t" + " setp.lt.s32 %%p, %3, %4;\n\t" + " selp.s32 %0, %3, %5, %%p;\n\t" + " selp.s32 %1, 1, 2, %%p;\n\t" + " selp.s64 %2, %6, %7, %%p;\n\t" + "}\n" + : "=r"(filter_r_), "=r"(next_idx), "=l"(reset_bytes) + : "r"(filter_r_), "r"(problem_size_.R), "r"(start_r_), + "l"(reset_bytes_s_), "l"(reset_bytes_r_)); +#endif } // offset pointers by offset_bytes @@ -619,7 +633,7 @@ class Conv2dDgradOutputGradientTileAccessIteratorOptimized < CUTLASS_PRAGMA_UNROLL for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { - clear_mask(v_idx, filter_k_ >= problem_size.K); + clear_mask(v_idx, filter_k_ + v_idx * AccessType::kElements >= problem_size.K); } set_iteration_index(0); diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h index add089af91..9317ea0cd9 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_analytic.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -67,7 +67,8 @@ template < typename Element_, typename Layout_, typename ThreadMap_, - typename AccessType_ = cutlass::AlignedArray + typename AccessType_ = cutlass::AlignedArray, + conv::GroupMode GroupMode_ = conv::GroupMode::kNone > class Conv2dFpropActivationTileAccessIteratorAnalytic { public: @@ -89,6 +90,7 @@ class Conv2dFpropActivationTileAccessIteratorAnalytic { static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; static int const kConvDim = 2; using ConvProblemSize = typename conv::Conv2dProblemSize; + static conv::GroupMode const kGroupMode = GroupMode_; static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; @@ -119,6 +121,11 @@ class Conv2dFpropActivationTileAccessIteratorAnalytic { int filter_c_; int filter_r_; int filter_s_; + int filter_c_init_; + int group_idx_offset_; + int channels_per_group_; + int crs_cnt_; + int crs_per_group_; int offset_n_[ThreadMap::Iterations::kStrided]; int offset_p_[ThreadMap::Iterations::kStrided]; @@ -137,6 +144,8 @@ class Conv2dFpropActivationTileAccessIteratorAnalytic { params_(params), problem_size_(problem_size), pointer_(reinterpret_cast(ptr)), + crs_cnt_(0), + group_idx_offset_(0), filter_c_(0), filter_r_(0), filter_s_(0) { @@ -145,6 +154,12 @@ class Conv2dFpropActivationTileAccessIteratorAnalytic { filter_c_ = threadblock_offset.column() + thread_coord.contiguous(); + if (kGroupMode != conv::GroupMode::kNone) { + filter_c_init_ = filter_c_; + channels_per_group_ = problem_size_.C / problem_size_.groups; + crs_per_group_ = problem_size_.S * problem_size_.R * ((channels_per_group_ + Shape::kColumn - 1) / Shape::kColumn); + } + CUTLASS_PRAGMA_UNROLL for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { int offset_npq = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; @@ -182,6 +197,10 @@ class Conv2dFpropActivationTileAccessIteratorAnalytic { CUTLASS_HOST_DEVICE void advance() { // moves to the next tile + if (kGroupMode != conv::GroupMode::kNone) { + ++crs_cnt_; + } + ++filter_s_; if (filter_s_ < problem_size_.S) { return; @@ -192,8 +211,19 @@ class Conv2dFpropActivationTileAccessIteratorAnalytic { return; } filter_r_ = 0; - - filter_c_ += Shape::kColumn * problem_size_.split_k_slices; + + if (kGroupMode == conv::GroupMode::kNone) { + filter_c_ += Shape::kColumn * problem_size_.split_k_slices; + } else { + if (crs_cnt_ == crs_per_group_) { + // moves to next group + crs_cnt_ = 0; + ++group_idx_offset_; + filter_c_ = group_idx_offset_ * channels_per_group_ + filter_c_init_; + } else { + filter_c_ += Shape::kColumn * problem_size_.split_k_slices; + } + } } /// Returns the coordinate in the activations tensor X that is currently pointed to @@ -273,7 +303,7 @@ class Conv2dFpropActivationTileAccessIteratorAnalytic { static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % AccessType::kElements) { + if ((problem_size.C / problem_size.groups) % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h b/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h index d95f3758d9..5a4489c017 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h 
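
The grouped-conv path added to `Conv2dFpropActivationTileAccessIteratorAnalytic::advance()` counts k-iterations per group (`crs_per_group_`) and, once a group is exhausted, resets `filter_c_` to the next group's channel base. A simplified sketch of just that decision (the surrounding S/R wrapping is omitted and the helper name is illustrative):

```cpp
// Decide the next channel offset after filter_s_ and filter_r_ have wrapped.
inline int next_filter_c(int filter_c, int filter_c_init, int& crs_cnt, int& group_idx,
                         int crs_per_group, int channels_per_group,
                         int tile_k, int split_k_slices, bool grouped) {
  if (!grouped) {
    return filter_c + tile_k * split_k_slices;      // ungrouped advance
  }
  if (crs_cnt == crs_per_group) {                   // consumed every k-iteration of this group
    crs_cnt = 0;
    ++group_idx;                                    // jump to the next group's channel range
    return group_idx * channels_per_group + filter_c_init;
  }
  return filter_c + tile_k * split_k_slices;        // still inside the current group
}
```
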
+++ b/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_few_channels.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h b/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h index a4bc830a70..3f1f2bc141 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_fixed_channels.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h index 147d4f1aeb..243d724b36 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_activation_tile_access_iterator_optimized.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -388,7 +388,7 @@ class Conv2dFpropActivationTileAccessIteratorOptimized { static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % AccessType::kElements) { + if ((problem_size.C / problem_size.groups) % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h index 08d3176d73..1725db5af5 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -66,7 +66,9 @@ template < typename Element_, typename Layout_, typename ThreadMap_, - typename AccessType_ = cutlass::AlignedArray + typename AccessType_ = cutlass::AlignedArray, + conv::GroupMode GroupMode_ = conv::GroupMode::kNone, + bool IsDeconv_ = false > class Conv2dFpropFilterTileAccessIteratorAnalytic { public: @@ -84,16 +86,18 @@ class Conv2dFpropFilterTileAccessIteratorAnalytic { using TensorCoord = typename Layout::TensorCoord; using Index = typename Layout::Index; using LongIndex = typename Layout::LongIndex; + static bool const IsDeconv = IsDeconv_; static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; static int const kConvDim = 2; using ConvProblemSize = typename conv::Conv2dProblemSize; + static conv::GroupMode const kGroupMode = GroupMode_; static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), "Vectors implied by the thread map must be divisible by the access type."); - + // // Simplifying assertions // @@ -118,8 +122,14 @@ class Conv2dFpropFilterTileAccessIteratorAnalytic { int filter_r_; int filter_s_; int filter_c_; + int filter_c_init_; + int crs_cnt_; + int crs_per_group_; + int group_idx_offset_c_; + int channels_per_group_; int offset_k_[ThreadMap::Iterations::kStrided]; + int group_idx_offset_k_[ThreadMap::Iterations::kStrided]; public: @@ -134,6 +144,8 @@ class Conv2dFpropFilterTileAccessIteratorAnalytic { params_(params), problem_size_(problem_size), pointer_(reinterpret_cast(ptr)), + crs_cnt_(0), + group_idx_offset_c_(0), filter_r_(0), filter_s_(0), filter_c_(0) { @@ -142,9 +154,26 @@ class Conv2dFpropFilterTileAccessIteratorAnalytic { filter_c_ = threadblock_offset.row() + thread_coord.contiguous(); + auto input_channels = (IsDeconv ? problem_size_.K : problem_size_.C); + auto output_channels = (IsDeconv ? 
problem_size_.C : problem_size_.K); + + if (kGroupMode != conv::GroupMode::kNone) { + filter_c_init_ = filter_c_; + if (kGroupMode == conv::GroupMode::kDepthwise){ + channels_per_group_ = 1; + crs_per_group_ = problem_size_.S * problem_size_.R; + } else { + channels_per_group_ = input_channels / problem_size_.groups; + crs_per_group_ = problem_size_.S * problem_size_.R * ((channels_per_group_ + Shape::kRow - 1) / Shape::kRow); + } + } + CUTLASS_PRAGMA_UNROLL for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { offset_k_[s] = threadblock_offset.column() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + if (kGroupMode != conv::GroupMode::kNone && kGroupMode != conv::GroupMode::kDepthwise) { + group_idx_offset_k_[s] = (thread_coord.strided() + s * ThreadMap::Delta::kStrided) / (output_channels / problem_size_.groups); + } } set_iteration_index(0); @@ -168,6 +197,10 @@ class Conv2dFpropFilterTileAccessIteratorAnalytic { CUTLASS_HOST_DEVICE void advance() { // moves to the next tile + if (kGroupMode != conv::GroupMode::kNone) { + ++crs_cnt_; + } + ++filter_s_; if (filter_s_ < problem_size_.S) { return; @@ -179,8 +212,21 @@ class Conv2dFpropFilterTileAccessIteratorAnalytic { return; } filter_r_ = 0; - - filter_c_ += Shape::kRow * problem_size_.split_k_slices; + + if (kGroupMode == conv::GroupMode::kNone) { + filter_c_ += Shape::kRow * problem_size_.split_k_slices; + } else { + if (crs_cnt_ == crs_per_group_) { + crs_cnt_ = 0; + filter_c_ = filter_c_init_; + if (kGroupMode != conv::GroupMode::kDepthwise) { + // moves to next group + ++group_idx_offset_c_; + } + } else { + filter_c_ += Shape::kRow * problem_size_.split_k_slices; + } + } } /// Returns the coordinate in the filter tensor W that is currently pointed to @@ -200,8 +246,17 @@ class Conv2dFpropFilterTileAccessIteratorAnalytic { TensorCoord coord = at(); - return coord.n() < problem_size_.K && - coord.c() < problem_size_.C; + auto input_channels = (IsDeconv ? problem_size_.K : problem_size_.C); + auto output_channels = (IsDeconv ? problem_size_.C : problem_size_.K); + + if (kGroupMode == conv::GroupMode::kNone) { + return coord.n() < output_channels && coord.c() < input_channels; + } else if (kGroupMode == conv::GroupMode::kDepthwise) { + return coord.n() < output_channels && coord.c() < 1; // channels_per_group_ is always equal to ONE. + } else { + return coord.n() < output_channels && coord.c() < channels_per_group_ && + group_idx_offset_c_ == group_idx_offset_k_[iteration_strided_]; + } } /// Returns a pointer to the vector starting at the current coordinate @@ -242,19 +297,22 @@ class Conv2dFpropFilterTileAccessIteratorAnalytic { CUTLASS_HOST_DEVICE static Status can_implement(Conv2dProblemSize const &problem_size) { + auto input_channels = (IsDeconv ? problem_size.K : problem_size.C); + auto output_channels = (IsDeconv ? 
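
The `IsDeconv_` parameter added above swaps which problem-size extent plays the role of input versus output channels in the filter iterator's `valid()` test, and the group modes tighten the channel bound. A hedged host-side model (the real iterator also matches a per-iteration group index, which is omitted here; the enum and function are illustrative):

```cpp
enum class GroupModeSketch { kNone, kGrouped, kDepthwise };

inline bool filter_coord_valid(int coord_n, int coord_c, int C, int K,
                               bool is_deconv, GroupModeSketch mode, int channels_per_group) {
  int input_channels  = is_deconv ? K : C;   // deconv reads the filter with roles swapped
  int output_channels = is_deconv ? C : K;

  switch (mode) {
    case GroupModeSketch::kNone:      return coord_n < output_channels && coord_c < input_channels;
    case GroupModeSketch::kDepthwise: return coord_n < output_channels && coord_c < 1;
    default:                          return coord_n < output_channels && coord_c < channels_per_group;
  }
}
```
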
problem_size.C : problem_size.K); + // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % AccessType::kElements) { + if ((input_channels / problem_size.groups) % AccessType::kElements) { return Status::kErrorInvalidProblem; } if (platform::is_same>::value) { - if (problem_size.K % 32) { + if (output_channels % 32) { return Status::kErrorInvalidProblem; } } if (platform::is_same>::value) { - if (problem_size.K % 64) { + if (output_channels % 64) { return Status::kErrorInvalidProblem; } } diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h index 7fb30ad957..a1291aa01c 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_few_channels.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h index 7b1012356f..e90d501745 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_fixed_channels.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without diff --git a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h index 2f6a1243ba..4c2343c32c 100644 --- a/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_optimized.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
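
The reworked `can_implement()` above checks vector alignment against the per-group input-channel count and, for interleaved output layouts, requires the output-channel count to be a multiple of the interleave factor (e.g. 32 or 64). A compact illustrative helper, assuming those are the only two constraints of interest:

```cpp
enum class StatusSketch { kSuccess, kErrorInvalidProblem };

inline StatusSketch filter_can_implement(int C, int K, int groups, bool is_deconv,
                                         int access_elements,
                                         int output_interleave /* 0 if layout is not interleaved */) {
  int input_channels  = is_deconv ? K : C;
  int output_channels = is_deconv ? C : K;

  if ((input_channels / groups) % access_elements) {                  // vectorized contiguous loads
    return StatusSketch::kErrorInvalidProblem;
  }
  if (output_interleave && (output_channels % output_interleave)) {   // interleaved output layout
    return StatusSketch::kErrorInvalidProblem;
  }
  return StatusSketch::kSuccess;
}
```
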
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -67,7 +67,8 @@ template < typename Element_, typename Layout_, typename ThreadMap_, - typename AccessType_ = cutlass::AlignedArray + typename AccessType_ = cutlass::AlignedArray, + bool IsDeconv_ = false > class Conv2dFpropFilterTileAccessIteratorOptimized{ public: @@ -85,6 +86,7 @@ class Conv2dFpropFilterTileAccessIteratorOptimized{ using TensorCoord = typename Layout::TensorCoord; using Index = typename Layout::Index; using LongIndex = typename Layout::LongIndex; + static bool const IsDeconv = IsDeconv_; static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; static int const kConvDim = 2; @@ -145,6 +147,7 @@ class Conv2dFpropFilterTileAccessIteratorOptimized{ uint32_t predicates_[kAccessesPerVector]; int filter_rs_; int filter_c_; + int channels_per_group_; // // Assertions @@ -175,10 +178,11 @@ class Conv2dFpropFilterTileAccessIteratorOptimized{ filter_c_ = threadblock_offset.row() + thread_coord.contiguous(); Index column = threadblock_offset.column() + thread_coord.strided(); + channels_per_group_ = (IsDeconv ? problem_size_.K : problem_size_.C) / problem_size_.groups; CUTLASS_PRAGMA_UNROLL for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < problem_size_.K) ? 1u : 0); + uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < (IsDeconv ? problem_size_.C : problem_size_.K)) ? 1u : 0); CUTLASS_PRAGMA_UNROLL for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { @@ -188,7 +192,7 @@ class Conv2dFpropFilterTileAccessIteratorOptimized{ CUTLASS_PRAGMA_UNROLL for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { - clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C); + clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= channels_per_group_); } pointer_ += ( @@ -229,7 +233,7 @@ class Conv2dFpropFilterTileAccessIteratorOptimized{ CUTLASS_PRAGMA_UNROLL for (int v_idx = 0; v_idx < kAccessesPerVector; ++v_idx) { - clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= problem_size_.C); + clear_mask(v_idx, filter_c_ + v_idx * AccessType::kElements >= channels_per_group_); } pointer_ += next; @@ -285,19 +289,22 @@ class Conv2dFpropFilterTileAccessIteratorOptimized{ CUTLASS_HOST_DEVICE static Status can_implement(Conv2dProblemSize const &problem_size) { + auto input_channels = (IsDeconv ? problem_size.K : problem_size.C); + auto output_channels = (IsDeconv ? 
problem_size.C : problem_size.K); + // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % AccessType::kElements) { + if ((input_channels / problem_size.groups) % AccessType::kElements) { return Status::kErrorInvalidProblem; } if (platform::is_same>::value) { - if (problem_size.K % 32) { + if (output_channels % 32) { return Status::kErrorInvalidProblem; } } if (platform::is_same>::value) { - if (problem_size.K % 64) { + if (output_channels % 64) { return Status::kErrorInvalidProblem; } } diff --git a/include/cutlass/conv/threadblock/conv2d_params.h b/include/cutlass/conv/threadblock/conv2d_params.h index 1ba9532cea..d34bc9faf1 100644 --- a/include/cutlass/conv/threadblock/conv2d_params.h +++ b/include/cutlass/conv/threadblock/conv2d_params.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -554,20 +554,20 @@ struct Conv2dDgradOutputGradientIteratorOptimizedParams { // next S inc_next[0] = conv_sign * ( - layout.stride()[0] * problem_size.dilation_w + (int64_t)layout.stride()[0] * problem_size.dilation_w ) * element_size_bits / 8; // next R inc_next[1] = conv_sign * ( - layout.stride()[1] * problem_size.dilation_h - - (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w + (int64_t)layout.stride()[1] * problem_size.dilation_h + - (problem_size.S - 1) * (int64_t)layout.stride()[0] * problem_size.dilation_w ) * element_size_bits / 8; // next K inc_next[2] = ( threadblock_shape.column() * problem_size.split_k_slices - - conv_sign * (problem_size.R - 1) * layout.stride()[1] * problem_size.dilation_h - - conv_sign * (problem_size.S - 1) * layout.stride()[0] * problem_size.dilation_w + - conv_sign * (problem_size.R - 1) * (int64_t)layout.stride()[1] * problem_size.dilation_h + - conv_sign * (problem_size.S - 1) * (int64_t)layout.stride()[0] * problem_size.dilation_w ) * element_size_bits / 8; // logical offset added to internal channel counter - units are elements, not bytes @@ -614,12 +614,12 @@ struct Conv2dStridedDgradOutputGradientIteratorOptimizedParams { // next S inc_next[0] = conv_sign * ( - layout.stride()[0] * problem_size.dilation_w + (int64_t)layout.stride()[0] * problem_size.dilation_w ) * element_size_bits / 8; // next R inc_next[1] = conv_sign * ( - layout.stride()[1] * problem_size.dilation_h + (int64_t)layout.stride()[1] * problem_size.dilation_h ) * element_size_bits / 8; // next K @@ -670,18 +670,18 @@ struct Conv2dDgradFilterIteratorOptimizedParams { TRACE_CONV_INITIALIZERS("conv2d_dgrad", "filter", element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); - inc_next_strided = (layout.stride()[2] * threadmap_delta.strided() * element_size_bits) / 8; + inc_next_strided = ((int64_t)layout.stride()[2] * threadmap_delta.strided() * element_size_bits) / 8; inc_next_rs = - ( layout.stride()[0] - - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2] + ( (int64_t)layout.stride()[0] + - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[2] ) * element_size_bits / 8; inc_next_k = ( - threadblock_shape.row() * problem_size.split_k_slices * layout.stride()[2] - - (problem_size.R 
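
The `(int64_t)` casts introduced in `conv2d_params.h` guard the byte-increment math against 32-bit overflow: the layout strides are 32-bit, and their products with dilation, thread-map deltas, and the element size can exceed `INT32_MAX` for large tensors. A small standalone check with illustrative numbers:

```cpp
#include <cstdint>
#include <cstdio>

int main() {
  int32_t stride = 512 * 1024 * 1024;   // e.g. an outer NHWC stride of a very large tensor
  int     bits   = 16;                  // element size in bits (half_t)

  // Promote to 64-bit before multiplying, as the patch does with (int64_t)layout.stride()[i].
  int64_t product = (int64_t)stride * bits;        // 8,589,934,592
  int64_t inc     = product / 8;                   // byte increment

  // The same product computed in 32-bit arithmetic would exceed INT32_MAX and wrap.
  std::printf("product=%lld (INT32_MAX=%d), byte increment=%lld\n",
              (long long)product, INT32_MAX, (long long)inc);
  return 0;
}
```
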
* problem_size.S - 1) * layout.stride()[0] - - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2] + threadblock_shape.row() * problem_size.split_k_slices * (int64_t)layout.stride()[2] + - (problem_size.R * problem_size.S - 1) * (int64_t)layout.stride()[0] + - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[2] ) * element_size_bits / 8; filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices; @@ -730,26 +730,26 @@ struct Conv2dStridedDgradFilterIteratorOptimizedParams { // next S inc_next[0] = - ( layout.stride()[0] * problem_size.stride_w + ( (int64_t)layout.stride()[0] * problem_size.stride_w //- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2] ) * element_size_bits / 8; // next R inc_next[1] = - ( layout.stride()[1] * problem_size.stride_h + ( (int64_t)layout.stride()[1] * problem_size.stride_h //- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2] ) * element_size_bits / 8; // next K inc_next[2] = ( - threadblock_shape.row() * problem_size.split_k_slices * layout.stride()[2] + threadblock_shape.row() * problem_size.split_k_slices * (int64_t)layout.stride()[2] //- (problem_size.R * problem_size.S - 1) * layout.stride()[0] //- (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2] ) * element_size_bits / 8; // offset in units of bytes to move the pointer in backward direction - reset_bytes = (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[2] + reset_bytes = (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[2] * element_size_bits / 8; filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices; @@ -800,13 +800,13 @@ struct Conv2dWgradOutputGradientIteratorOptimizedParams { element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); // Incremental offsets in unites of bytes (number of elements) * sizeof_bits::value / 8 - offset_next_strided = (threadmap_delta.strided() * layout.stride()[0]) + offset_next_strided = (threadmap_delta.strided() * (int64_t)layout.stride()[0]) * element_size_bits / 8; offset_next_contiguous = (threadmap_delta.contiguous()) * element_size_bits / 8; - inc_next_npq = (threadblock_shape.column() * problem_size.split_k_slices * layout.stride()[0]) + inc_next_npq = (threadblock_shape.column() * problem_size.split_k_slices * (int64_t)layout.stride()[0]) * element_size_bits / 8; } }; @@ -891,4 +891,3 @@ struct PredicatedScaleBiasVectorAccessIteratorParams { } // namespace cutlass ///////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/include/cutlass/conv/threadblock/conv2d_tile_iterator.h b/include/cutlass/conv/threadblock/conv2d_tile_iterator.h index 66dd75d26f..17f4594ba5 100644 --- a/include/cutlass/conv/threadblock/conv2d_tile_iterator.h +++ b/include/cutlass/conv/threadblock/conv2d_tile_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -104,6 +104,11 @@ class TileIterator { return TileAccessIterator::getParams(problem_size, layout); } + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + tile_access_iterator_.set_iteration_index(index); + } /// Adds a pointer offset in units of Element CUTLASS_HOST_DEVICE diff --git a/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h index ec68cc89a5..3e3a4f155d 100644 --- a/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_analytic.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -195,7 +195,7 @@ class Conv2dWgradActivationTileAccessIteratorAnalytic { s = filter_s_[iteration_contiguous_]; } else { - /// Multiple access to support non-128b alignment in contiguous dimenstion + /// Multiple access to support non-128b alignment in contiguous dimension c = (filter_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements) % problem_size_.C; int wrap_c = (filter_c_[iteration_contiguous_] + iteration_vector_ * AccessType::kElements) / problem_size_.C; s = (filter_s_[iteration_contiguous_] + wrap_c) % problem_size_.S; @@ -268,7 +268,7 @@ class Conv2dWgradActivationTileAccessIteratorAnalytic { static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % AccessType::kElements) { + if (problem_size.C % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h index fec9dcda1c..8cbcc3d9fb 100644 --- a/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_wgrad_activation_tile_access_iterator_optimized.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -212,7 +212,7 @@ class Conv2dWgradActivationTileAccessIteratorOptimized { if (kAccessesPerVector > 1) { // This code section is only to support non-128b alignment - // Multiple access to support non-128b alignment in contiguous dimenstion + // Multiple access to support non-128b alignment in contiguous dimension int wrap_c; params_.c_divmod(wrap_c, c, c + iteration_vector_ * AccessType::kElements); @@ -304,7 +304,7 @@ class Conv2dWgradActivationTileAccessIteratorOptimized { static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % AccessType::kElements) { + if (problem_size.C % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h index 0a26d646c7..793649dbea 100644 --- a/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_analytic.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -243,7 +243,7 @@ class Conv2dWgradOutputGradientTileAccessIteratorAnalytic { static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % AccessType::kElements) { + if (problem_size.K % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h index aac0e3c3c1..07233d8924 100644 --- a/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv2d_wgrad_output_gradient_tile_access_iterator_optimized.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
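
The swapped alignment checks for the wgrad iterators above reflect which tensor each operand reads: the output-gradient iterator is contiguous (and vectorized) along K, while the activation iterator is contiguous along C, so each must test its own extent. A tiny illustrative helper, not CUTLASS API:

```cpp
// Both operands of the implicit-GEMM wgrad must be vector-aligned along
// their own contiguous channel dimension.
inline bool wgrad_iterators_aligned(int C, int K, int access_elements) {
  bool output_gradient_ok = (K % access_elements) == 0;  // operand A, contiguous in K
  bool activation_ok      = (C % access_elements) == 0;  // operand B, contiguous in C
  return output_gradient_ok && activation_ok;
}
```
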
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -293,7 +293,7 @@ class Conv2dWgradOutputGradientTileAccessIteratorOptimized { static Status can_implement(Conv2dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % AccessType::kElements) { + if (problem_size.K % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h index 331132c759..943ab88cfc 100644 --- a/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_analytic.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -251,7 +251,7 @@ class Conv3dDgradFilterTileAccessIteratorAnalytic { static Status can_implement(Conv3dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % (128/sizeof_bits::value)) { + if (problem_size.C % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h index a5a760db23..2d5837dd3d 100644 --- a/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_dgrad_filter_tile_access_iterator_optimized.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -272,7 +272,7 @@ class Conv3dDgradFilterTileAccessIteratorOptimized { static Status can_implement(Conv3dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % (128/sizeof_bits::value)) { + if (problem_size.C % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h index 0fc623e2ef..30b7f2fcf6 100644 --- a/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_analytic.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -325,7 +325,7 @@ class Conv3dDgradOutputGradientTileAccessIteratorAnalytic < static Status can_implement(ConvProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % (128/sizeof_bits::value)) { + if (problem_size.K % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h index 3f6b36e9dd..5a53c8cbd5 100644 --- a/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_dgrad_output_gradient_tile_access_iterator_optimized.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -466,7 +466,7 @@ class Conv3dDgradOutputGradientTileAccessIteratorOptimized { } // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % (128/sizeof_bits::value)) { + if (problem_size.K % AccessType::kElements) { return Status::kErrorNotSupported; } diff --git a/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h index a192bfd4be..f0f9a86a34 100644 --- a/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_analytic.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -272,7 +272,7 @@ class Conv3dFpropActivationTileAccessIteratorAnalytic { static Status can_implement(ConvProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % (128/sizeof_bits::value)) { + if (problem_size.C % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h index d0e89bd68b..78b270eb9a 100644 --- a/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_fprop_activation_tile_access_iterator_optimized.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -455,7 +455,7 @@ class Conv3dFpropActivationTileAccessIteratorOptimized { static Status can_implement(Conv3dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % (128/sizeof_bits::value)) { + if (problem_size.C % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h index 8b388b0e8c..9f04adc40b 100644 --- a/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_analytic.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -64,7 +64,8 @@ namespace threadblock { template < typename Shape_, typename Element_, - typename ThreadMap_ + typename ThreadMap_, + bool IsDeconv_ = false > class Conv3dFpropFilterTileAccessIteratorAnalytic { public: @@ -82,6 +83,7 @@ class Conv3dFpropFilterTileAccessIteratorAnalytic { using TensorCoord = typename Layout::TensorCoord; using Index = typename Layout::Index; using LongIndex = typename Layout::LongIndex; + static bool const IsDeconv = IsDeconv_; static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kAnalytic; static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; static int const kConvDim = 3; @@ -198,8 +200,11 @@ class Conv3dFpropFilterTileAccessIteratorAnalytic { TensorCoord coord = at(); - return coord.n() < problem_size_.K && - coord.c() < problem_size_.C; + auto input_channels = (IsDeconv ? problem_size_.K : problem_size_.C); + auto output_channels = (IsDeconv ? problem_size_.C : problem_size_.K); + + return coord.n() < output_channels && + coord.c() < input_channels; } /// Returns a pointer to the vector starting at the current coordinate @@ -233,9 +238,10 @@ class Conv3dFpropFilterTileAccessIteratorAnalytic { /// Determines whether the Implicit GEMM can execute the given problem. CUTLASS_HOST_DEVICE static Status can_implement(ConvProblemSize const &problem_size) { - + auto input_channels = (IsDeconv ? problem_size.K : problem_size.C); + auto output_channels = (IsDeconv ? 
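
The 3-D iterators above now validate alignment against `AccessType::kElements` rather than a hard-coded 128-bit access, so problem sizes rejected under the old check become admissible when the iterator is instantiated with narrower accesses. A quick standalone comparison with illustrative values:

```cpp
#include <cstdio>

int main() {
  int const element_bits    = 16;   // half_t
  int const access_elements = 4;    // AccessType::kElements, i.e. 64-bit loads
  int const C               = 36;   // channel count to validate

  bool old_ok = (C % (128 / element_bits)) == 0;  // requires C % 8 == 0 -> rejected
  bool new_ok = (C % access_elements) == 0;       // requires C % 4 == 0 -> accepted
  std::printf("old_ok=%d new_ok=%d\n", old_ok, new_ok);
  return 0;
}
```
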
problem_size.C : problem_size.K); // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % (128/sizeof_bits::value)) { + if (input_channels % AccessType::kElements) { return Status::kErrorInvalidProblem; } return Status::kSuccess; diff --git a/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h index 9b69dbcb43..efe34497f5 100644 --- a/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_fprop_filter_tile_access_iterator_optimized.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -66,7 +66,8 @@ template < typename Shape_, typename Element_, typename Layout_, - typename ThreadMap_ + typename ThreadMap_, + bool IsDeconv_ = false > class Conv3dFpropFilterTileAccessIteratorOptimized{ public: @@ -84,6 +85,7 @@ class Conv3dFpropFilterTileAccessIteratorOptimized{ using TensorCoord = typename Layout::TensorCoord; using Index = typename Layout::Index; using LongIndex = typename Layout::LongIndex; + static bool const IsDeconv = IsDeconv_; static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; static int const kConvDim = 3; @@ -172,11 +174,11 @@ class Conv3dFpropFilterTileAccessIteratorOptimized{ CUTLASS_PRAGMA_UNROLL for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { - uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < problem_size_.K) ? 1u : 0); + uint32_t pred = ((column + s * ThreadMap::Delta::kStrided < (IsDeconv ? problem_size_.C : problem_size_.K)) ? 1u : 0); predicates_ |= (pred << s); } - if (filter_c_ >= problem_size.C) { + if (filter_c_ >= (IsDeconv ? problem_size_.K : problem_size_.C)) { predicates_ = 0u; } @@ -214,7 +216,7 @@ class Conv3dFpropFilterTileAccessIteratorOptimized{ filter_c_ += params_.filter_c_delta; } - if (filter_c_ >= problem_size_.C) { + if (filter_c_ >= (IsDeconv ? problem_size_.K : problem_size_.C)) { predicates_ = 0; } @@ -258,12 +260,12 @@ class Conv3dFpropFilterTileAccessIteratorOptimized{ /// Determines whether the Implicit GEMM can execute the given problem. CUTLASS_HOST_DEVICE static Status can_implement(Conv3dProblemSize const &problem_size) { + auto input_channels = (IsDeconv ? problem_size.K : problem_size.C); // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % (128/sizeof_bits::value)) { + if (input_channels % AccessType::kElements) { return Status::kErrorInvalidProblem; } - return Status::kSuccess; } }; diff --git a/include/cutlass/conv/threadblock/conv3d_params.h b/include/cutlass/conv/threadblock/conv3d_params.h index 5ad1e4fa3d..ac422b8f05 100644 --- a/include/cutlass/conv/threadblock/conv3d_params.h +++ b/include/cutlass/conv/threadblock/conv3d_params.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. 
All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -304,8 +304,8 @@ struct Conv3dDgradOutputGradientIteratorOptimizedParams { // logical offset added to internal channel counter - units are elements, not bytes filter_k_delta = threadblock_shape.column() * problem_size.split_k_slices; } - }; + ///////////////////////////////////////////////////////////////////////////////////////////////// /// Parameters object for Conv2d DGRAD Filter (w) iterator @@ -343,18 +343,18 @@ struct Conv3dDgradFilterIteratorOptimizedParams { TRACE_CONV_INITIALIZERS("conv3d_dgrad", "filter", element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); - inc_next_strided = (layout.stride()[3] * threadmap_delta.strided() * element_size_bits) / 8; + inc_next_strided = ((int64_t)layout.stride()[3] * threadmap_delta.strided() * element_size_bits) / 8; inc_next_trs = - ( layout.stride()[0] - - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[3] + ( (int64_t)layout.stride()[0] + - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[3] ) * element_size_bits / 8; inc_next_k = ( - threadblock_shape.row() * problem_size.split_k_slices * layout.stride()[3] - - (problem_size.T * problem_size.R * problem_size.S - 1) * layout.stride()[0] - - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * layout.stride()[3] + threadblock_shape.row() * problem_size.split_k_slices * (int64_t)layout.stride()[3] + - (problem_size.T * problem_size.R * problem_size.S - 1) * (int64_t)layout.stride()[0] + - (threadmap_iterations.strided() - 1) * threadmap_delta.strided() * (int64_t)layout.stride()[3] ) * element_size_bits / 8; filter_k_delta = threadblock_shape.row() * problem_size.split_k_slices; @@ -408,13 +408,13 @@ struct Conv3dWgradOutputGradientIteratorOptimizedParams { element_size_bits, threadblock_shape, thread_count, access_size, threadmap_iterations, threadmap_delta); // Incremental offsets in unites of bytes (number of elements) * element_size_bits / 8 - offset_next_strided = (threadmap_delta.strided() * layout.stride()[0]) + offset_next_strided = (threadmap_delta.strided() * (int64_t)layout.stride()[0]) * element_size_bits / 8; offset_next_contiguous = (threadmap_delta.contiguous()) * element_size_bits / 8; - inc_next_nzpq = (threadblock_shape.column() * problem_size.split_k_slices * layout.stride()[0]) + inc_next_nzpq = (threadblock_shape.column() * problem_size.split_k_slices * (int64_t)layout.stride()[0]) * element_size_bits / 8; // Precompute several quantities for fast modulo arithmetic. diff --git a/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h index ebf7e84ed8..cc8faea701 100644 --- a/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_analytic.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -270,7 +270,7 @@ class Conv3dWgradActivationTileAccessIteratorAnalytic { static Status can_implement(Conv3dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % (128/sizeof_bits::value)) { + if (problem_size.C % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h index 0413c7dc75..2b10d207fa 100644 --- a/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_wgrad_activation_tile_access_iterator_optimized.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -250,7 +250,7 @@ class Conv3dWgradActivationTileAccessIteratorOptimized { fast_divmod(p, q, residual, problem_size_.Q, params_.q_mul, params_.q_shr); int d = z * problem_size_.stride_d + precomputed_filter_t_[iteration_contiguous_]; - int h = p * problem_size_.stride_h + precomputed_filter_r_[iteration_contiguous_];; + int h = p * problem_size_.stride_h + precomputed_filter_r_[iteration_contiguous_]; int w = q * problem_size_.stride_w + precomputed_filter_s_[iteration_contiguous_]; return TensorCoord(n, d, h, w, filter_c_[iteration_contiguous_]); @@ -300,7 +300,7 @@ class Conv3dWgradActivationTileAccessIteratorOptimized { static Status can_implement(Conv3dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.K % (128/sizeof_bits::value)) { + if (problem_size.C % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h b/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h index 4d57f0ba48..be9d4fb7ac 100644 --- a/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h +++ b/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_analytic.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -248,7 +248,7 @@ class Conv3dWgradOutputGradientTileAccessIteratorAnalytic { static Status can_implement(Conv3dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % (128/sizeof_bits::value)) { + if (problem_size.K % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h b/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h index 2bc5971bd6..0ef145f19d 100644 --- a/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h +++ b/include/cutlass/conv/threadblock/conv3d_wgrad_output_gradient_tile_access_iterator_optimized.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -291,7 +291,7 @@ class Conv3dWgradOutputGradientTileAccessIteratorOptimized { static Status can_implement(Conv3dProblemSize const &problem_size) { // check alignment constraint on iterator's contiguous dimension - if (problem_size.C % (128/sizeof_bits::value)) { + if (problem_size.K % AccessType::kElements) { return Status::kErrorInvalidProblem; } diff --git a/include/cutlass/conv/threadblock/depthwise_direct_conv_params.h b/include/cutlass/conv/threadblock/depthwise_direct_conv_params.h new file mode 100644 index 0000000000..8023183499 --- /dev/null +++ b/include/cutlass/conv/threadblock/depthwise_direct_conv_params.h @@ -0,0 +1,230 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! + \file + \brief Extracts the host-params objects into non-template code. +*/ + +#pragma once + +#define TRACE_CONV_PARAMS_INITIALIZERS_ENABLED 0 + +#include "cutlass/cutlass.h" +#include "cutlass/fast_math.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" + +#if TRACE_CONV_PARAMS_INITIALIZERS_ENABLED +#include +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized +template +struct Depthwise2dFpropDirectConvParams; + +/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation +template +struct Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams; + +/// Parameters structure used for DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized +template +struct Depthwise2dFpropDirectConvFilterIteratorParams; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized +template<> +struct Depthwise2dFpropDirectConvParams { + + using Layout = layout::TensorNHWC; + + Layout layout; + + int32_t activation_tile_h; + int32_t activation_tile_w; + int32_t activation_tile_hw; + FastDivmod activation_tile_w_divmod; + + int filter[2]; + int stride[2]; + int dilation[2]; + int inc_next[2]; + FastDivmod pq_divmod; + FastDivmod q_divmod; + + int activation_load_count; + int activation_storage_elements; + int activation_size; + // + // Methods + // + + CUTLASS_HOST_DEVICE + Depthwise2dFpropDirectConvParams() { } + + CUTLASS_HOST_DEVICE + Depthwise2dFpropDirectConvParams( + Conv2dProblemSize const &problem_size, + Layout const &layout, ///< layout object + MatrixCoord threadblock_shape, ///< CTA threadblock Shape + Layout::TensorCoord threadblock_output_shape, ///< Output tile Shape per threadblock + const int element_size_bits, ///< bits of activation element + const int thread_count, ///< threads per threadblock + const int thread_count_contiguous, ///< number of threads for continuous dimension + const int element_per_load) ///< element per each load + : layout(layout) { + + filter[0] = problem_size.S; + filter[1] = problem_size.R; + + stride[0] = problem_size.stride_w; + stride[1] = problem_size.stride_h; + + dilation[0] = problem_size.dilation_w; + dilation[1] = problem_size.dilation_h; + + // Compute activation_tile size per threadblock because stride and dilation are runtime params. 
+ activation_tile_h = (threadblock_output_shape.h() - 1) * problem_size.stride_h + + (problem_size.R - 1) * problem_size.dilation_h + 1; + activation_tile_w = (threadblock_output_shape.w() - 1) * problem_size.stride_w + + (problem_size.S - 1) * problem_size.dilation_w + 1; + activation_tile_hw = activation_tile_h * activation_tile_w; + + activation_tile_w_divmod = FastDivmod(activation_tile_w); + + /// Below two values could not be templatized because the stride and dilation are runtime params + activation_load_count = (thread_count_contiguous * activation_tile_hw + (thread_count - 1)) / thread_count; + activation_storage_elements = activation_load_count * element_per_load * thread_count; + activation_size = activation_storage_elements * element_size_bits / 8; + + // Fastdivmod for output P, Q + int tiles_p = + (problem_size.P + (threadblock_output_shape.h() - 1)) / (threadblock_output_shape.h()); + int tiles_q = (problem_size.Q + (threadblock_output_shape.w() - 1)) / + (threadblock_output_shape.w()); + + pq_divmod = FastDivmod(tiles_p * tiles_q); + q_divmod = FastDivmod(tiles_q); + + // next S + inc_next[0] = problem_size.dilation_w; + // next R + inc_next[1] = (activation_tile_w * problem_size.dilation_h - (problem_size.S - 1) * problem_size.dilation_w); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Parameters structure used for DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation +template <> +struct Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams { + using Layout = layout::TensorNHWC; + + Layout layout; + + FastDivmod pq_divmod; + FastDivmod q_divmod; + + int activation_size; + + // + // Methods + // + + CUTLASS_HOST_DEVICE + Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams() {} + + CUTLASS_HOST_DEVICE + Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams( + Conv2dProblemSize const &problem_size, + Layout const &layout, ///< Layout object + MatrixCoord threadblock_shape, ///< Threadblock Shape + Layout::TensorCoord threadblock_output_shape, ///< Output tile Shape per threadblock + const int activation_size_ ///< Activation size loaded by iterator + ) + : layout(layout), + activation_size(activation_size_) { + // Fastdivmod for output P, Q + int tiles_p = + (problem_size.P + (threadblock_output_shape.h() - 1)) / (threadblock_output_shape.h()); + int tiles_q = + (problem_size.Q + (threadblock_output_shape.w() - 1)) / (threadblock_output_shape.w()); + + pq_divmod = FastDivmod(tiles_p * tiles_q); + q_divmod = FastDivmod(tiles_q); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Parameters structure used for DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized +template <> +struct Depthwise2dFpropDirectConvFilterIteratorParams { + using Layout = layout::TensorNHWC; + + Layout layout; + + int filter_size; + + bool is_convolution; + // + // Methods + // + + CUTLASS_HOST_DEVICE + Depthwise2dFpropDirectConvFilterIteratorParams() {} + + CUTLASS_HOST_DEVICE + Depthwise2dFpropDirectConvFilterIteratorParams( + Conv2dProblemSize const &problem_size, + Layout const &layout, ///< Layout object + MatrixCoord threadblock_shape, ///< Threadblock Shape + const int filter_size_) ///< Filter size loaded by iterator + : layout(layout), + filter_size(filter_size_), + is_convolution(problem_size.mode == Mode::kConvolution){} +}; + +} // namespace threadblock +} // namespace conv +} // 
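
The activation tile computed above is the input footprint of one threadblock's output tile: `(out - 1) * stride + (filter - 1) * dilation + 1` per spatial dimension, evaluated at runtime because stride and dilation are runtime parameters for this params struct. A small numeric check with made-up shapes:

```cpp
#include <cstdio>

int main() {
  int out_h = 8, out_w = 8;         // threadblock output tile
  int R = 3, S = 3;                 // filter extent
  int stride_h = 2, stride_w = 2;
  int dilation_h = 1, dilation_w = 1;

  int tile_h = (out_h - 1) * stride_h + (R - 1) * dilation_h + 1;  // 17
  int tile_w = (out_w - 1) * stride_w + (S - 1) * dilation_w + 1;  // 17
  std::printf("activation tile: %d x %d (%d elements per channel)\n",
              tile_h, tile_w, tile_h * tile_w);
  return 0;
}
```
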
namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h b/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h new file mode 100644 index 0000000000..192d961051 --- /dev/null +++ b/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_fixed_stride_dilation.h @@ -0,0 +1,314 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) + matrix from memory. + + This iterator assumes TensorNHWC layout of tensors in Global Memory. 
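+
+    Stride, dilation and the resulting activation tile shape are compile-time template
+    parameters (StrideShape, DilationShape, ActivationShape), so the per-thread load count is a
+    static constant; can_implement() rejects problems whose runtime stride or dilation differ
+    from the templated values.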
+*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/threadblock/depthwise_direct_conv_params.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template > +class DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation { + public: + // + // Types + // + + using Shape = Shape_; + using OutputTileShape = OutputTileShape_; + using Element = Element_; + using Layout = Layout_; + using TensorCoord = typename Layout::TensorCoord; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + // Compilation value of stride , dialtion and activation shape + using StrideShape = StrideShape_; + using DilationShape = DilationShape_; + using ActivationShape = ActivationShape_; + + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + static int const kActivationSize = ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess * ThreadMap::kThreads * + sizeof_bits::value / 8; + + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, "Require Iterations::kContiguous == 1"); + + static_assert(OutputTileShape::kN == 1, "Require OutputTileShape::kN == 1"); + static_assert(OutputTileShape::kC == Shape::kColumn, "Require OutputTile shape == channels per threadblock"); + + // + // Parameters structure + // + + using Params = Depthwise2dFpropDirectConvActivationIteratorFixedStrideDilationParams; + + private: + Conv2dProblemSize const &problem_size_; + Params const ¶ms_; + char const *pointer_; + + // Base channels for current threadblock + int base_c_; + // Base activation index for current threadblock + int offset_intial_npq_; + // Base activation coord for current threadblock + TensorCoord activatioin_base_; + // Intial thread positioin + int offset_initial_hwc_; + // Overall load instruction per thread. + int iterator_load_; + // thread loading position. 
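+  // iterator_hwc_ linearizes the thread position as
+  // ((h * ActivationShape::kW + w) * ThreadMap::Detail::ShapeVec::kContiguous + c), so at()
+  // recovers (h, w) and the channel vector index with integer div/mod.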
+ int iterator_hwc_; + // activation N is inside the Tensor or not + bool valid_n_; + + public: + + + CUTLASS_HOST_DEVICE + DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = + MatrixCoord() + ) + : params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + offset_intial_npq_(threadblock_offset.row()), + offset_initial_hwc_(thread_idx), + iterator_load_(0) { + + base_c_ = threadblock_offset.column(); + + set_iteration_index(0); + + set_activation_coord(offset_intial_npq_); + + } + + CUTLASS_HOST_DEVICE + void set_activation_coord(int offset_npq) { + int offset_inital_n, offset_inital_p, offset_inital_q; + int residual; + + params_.pq_divmod(offset_inital_n, residual, offset_npq); + params_.q_divmod(offset_inital_p, offset_inital_q, residual); + + int base_n = offset_inital_n; + + int base_h = + offset_inital_p * OutputTileShape::kH * StrideShape::kRow - problem_size_.pad_h; + + int base_w = + offset_inital_q * OutputTileShape::kW * StrideShape::kColumn - problem_size_.pad_w; + + activatioin_base_ = TensorCoord(base_n, base_h, base_w, base_c_); + + valid_n_ = activatioin_base_.n() < problem_size_.N; + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { + return Params( + problem_size, + layout, + {Shape::kRow, Shape::kColumn}, + {OutputTileShape::kN, OutputTileShape::kH, OutputTileShape::kW, OutputTileShape::kC}, + kActivationSize); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iterator_hwc_ = offset_initial_hwc_ + index * ThreadMap::kThreads; + iterator_load_ = index; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + // Go to next threadblock + offset_intial_npq_ += problem_size_.split_k_slices; + + set_iteration_index(0); + + set_activation_coord(offset_intial_npq_); + } + + /// Returns the coordinate in the activations tensor X that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + int c = iterator_hwc_ % ThreadMap::Detail::ShapeVec::kContiguous ; + int next = iterator_hwc_ / ThreadMap::Detail::ShapeVec::kContiguous ; + int h = next / ActivationShape::kW; + int w = next % ActivationShape::kW; + + c = c * AccessType::kElements; + + return activatioin_base_ + TensorCoord(0, h, w, c); + } + + /// Returns true if the current coordinate is within the activations tensor X + CUTLASS_HOST_DEVICE + bool valid() const { + TensorCoord coord = at(); + bool valid_c = coord.c() < problem_size_.C; + bool valid_h = coord.h() >= 0 && coord.h() < problem_size_.H; + bool valid_w = coord.w() >= 0 && coord.w() < problem_size_.W; + return valid_n_ ? 
valid_c & valid_h & valid_w : 0; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + AccessType const *ptr = + reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + + return ptr; + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + DepthwiseFpropActivationDirect2dConvTileAccessIteratorFixedStrideDilation &operator++() { + + ++iterator_load_; + iterator_hwc_ += ThreadMap::kThreads; + + if (iterator_load_ < ThreadMap::Iterations::kCount) { + return *this; + } + + iterator_load_ = 0; + iterator_hwc_ = offset_initial_hwc_; + + return *this; + } + + /// Determines the activation size loaded by iterator + CUTLASS_HOST_DEVICE + int get_load_size() { + return kActivationSize; + } + + /// Determines the iterations needed + CUTLASS_HOST_DEVICE + int get_iteration_num() { + return ThreadMap::Iterations::kCount; + } + + /// Determines whether the Depthwise fprop can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check stride and dilation constraint + if (problem_size.stride_h != StrideShape::kRow || problem_size.stride_w != StrideShape::kColumn) { + return Status::kErrorInvalidProblem; + } + + if (problem_size.dilation_h != DilationShape::kRow || problem_size.dilation_w != DilationShape::kColumn) { + return Status::kErrorInvalidProblem; + } + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h b/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h new file mode 100644 index 0000000000..a858a23f9f --- /dev/null +++ b/include/cutlass/conv/threadblock/depthwise_fprop_activation_tile_access_iterator_direct_conv_optimized.h @@ -0,0 +1,291 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM A (activation tile) + matrix from memory. + + This iterator assumes TensorNHWC layout of tensors in Global Memory. +*/ + +#pragma once + +#include "cutlass/array.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/threadblock/depthwise_direct_conv_params.h" +#include "cutlass/coord.h" +#include "cutlass/cutlass.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template > +class DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized { + public: + // + // Types + // + + using Shape = Shape_; + using OutputTileShape = OutputTileShape_; + using Element = Element_; + using Layout = Layout_; + using TensorCoord = typename Layout::TensorCoord; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, "Require Iterations::kContiguous == 1"); + + static_assert(OutputTileShape::kN == 1, "Require OutputTileShape::kN == 1"); + static_assert(OutputTileShape::kC == Shape::kColumn, "Require OutputTile shape == channels per threadblock"); + + // + // Parameters structure + // + + using Params = Depthwise2dFpropDirectConvParams; + + private: + Conv2dProblemSize const &problem_size_; + Params const ¶ms_; + char const *pointer_; + + // Base channels for current threadblock + int base_c_; + // Base activation index for current threadblock + int offset_intial_npq_; + 
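+  // (This npq index is linear over (n, tile_p, tile_q); the pq/q fast divmods held in Params
+  // unpack it back into coordinates inside set_activation_coord().)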
// Base activation coord for current threadblock + TensorCoord activatioin_base_; + // Intial thread positioin + int offset_initial_hwc_; + // Overall load instruction per thread. + int iterator_load_; + // thread loading position. + int iterator_hwc_; + // Number of loads for activations tensor X. + const int number_of_loads_; + + public: + + + CUTLASS_HOST_DEVICE + DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = + MatrixCoord() + ) + : params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + offset_intial_npq_(threadblock_offset.row()), + offset_initial_hwc_(thread_idx), + iterator_load_(0), + number_of_loads_(params.activation_load_count) { + + base_c_ = threadblock_offset.column(); + + set_activation_coord(offset_intial_npq_); + + set_iteration_index(0); + } + + CUTLASS_HOST_DEVICE + void set_activation_coord(int offset_npq) { + int offset_inital_n, offset_inital_p, offset_inital_q; + int residual; + + params_.pq_divmod(offset_inital_n, residual, offset_npq); + params_.q_divmod(offset_inital_p, offset_inital_q, residual); + + int base_n = offset_inital_n; + + int base_h = + offset_inital_p * OutputTileShape::kH * problem_size_.stride_h - problem_size_.pad_h; + + int base_w = + offset_inital_q * OutputTileShape::kW * problem_size_.stride_w - problem_size_.pad_w; + + activatioin_base_ = TensorCoord(base_n, base_h, base_w, base_c_); + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { + return Params( + problem_size, + layout, + {Shape::kRow, Shape::kColumn}, + {OutputTileShape::kN, OutputTileShape::kH, OutputTileShape::kW, OutputTileShape::kC}, + sizeof_bits::value, + ThreadMap::kThreads, + ThreadMap::Detail::ShapeVec::kContiguous, + ThreadMap::kElementsPerAccess); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iterator_hwc_ = offset_initial_hwc_ + index * ThreadMap::kThreads; + iterator_load_ = index; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * sizeof_bits::value / 8; + } + + CUTLASS_HOST_DEVICE + void advance() { + // Go to next threadblock + offset_intial_npq_ += problem_size_.split_k_slices; + + set_activation_coord(offset_intial_npq_); + } + + /// Returns the coordinate in the activations tensor X that is currently pointed to + /// by the iterator. 
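+  /// The (h, w) offset inside the activation tile is recovered from the linear thread position
+  /// via the precomputed activation_tile_w_divmod.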
+ CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int c = iterator_hwc_ % ThreadMap::Detail::ShapeVec::kContiguous ; + int next = iterator_hwc_ / ThreadMap::Detail::ShapeVec::kContiguous ; + int h, w; + params_.activation_tile_w_divmod(h, w, next) ; + + c = c * AccessType::kElements; + + return activatioin_base_ + TensorCoord(0, h, w, c); + } + + /// Returns true if the current coordinate is within the activations tensor X + CUTLASS_HOST_DEVICE + bool valid() const { + TensorCoord coord = at(); + + return coord.n() < problem_size_.N && coord.h() >= 0 && coord.h() < problem_size_.H && + coord.w() >= 0 && coord.w() < problem_size_.W && coord.c() < problem_size_.C; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + TensorCoord coord = at(); + LongIndex offset = params_.layout(coord); + + AccessType const *ptr = + reinterpret_cast(pointer_ + offset * sizeof_bits::value / 8); + + return ptr; + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + DepthwiseFpropActivationDirect2dConvTileAccessIteratorOptimized &operator++() { + + ++iterator_load_; + iterator_hwc_ += ThreadMap::kThreads; + + if (iterator_load_ < number_of_loads_) { + return *this; + } + + iterator_load_ = 0; + iterator_hwc_ = offset_initial_hwc_; + + return *this; + } + + /// Determines the activation size loaded by iterator + CUTLASS_HOST_DEVICE + int get_load_size() { + return params_.activation_size; + } + + /// Determines the iterations needed + CUTLASS_HOST_DEVICE + int get_iteration_num() { + return number_of_loads_; + } + + /// Determines whether the Depthwise fprop can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + // check alignment constraint on iterator's contiguous dimension + if (problem_size.C % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h b/include/cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h new file mode 100644 index 0000000000..50aeee006d --- /dev/null +++ b/include/cutlass/conv/threadblock/depthwise_fprop_direct_conv_multistage.h @@ -0,0 +1,551 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a multistage threadblock-scoped Implicit GEMM Convolution kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" +#include "cutlass/arch/cache_operation.h" +#include "cutlass/conv/threadblock/depthwise_mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Cache operation for operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | + // MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Cache operation for operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Epilogue stores the data into global memory + typename Epilogue_, + /// iterator implementation variants + conv::IteratorAlgorithm IteratorAlgorithm_ = conv::IteratorAlgorithm::kOptimized, + /// Used for partial specialization + typename Enable = bool> +class DepthwiseFpropDirectConvMultipleStage : + public DepthwiseDirectConvMmaBase { +public: + ///< Base class + using Base = DepthwiseDirectConvMmaBase; + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + ///< Iterates over tiles of A operand in global memory + using IteratorA = IteratorA_; + ///< Iterates over tiles of B operand in global memory + using IteratorB = IteratorB_; + ///< Policy describing tuning details + using Policy = Policy_; + + using Epilogue = Epilogue_; + + using SmemIteratorA = SmemIteratorA_; + 
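+  ///< The filter (operand B) tile is written to shared memory only during the prologue and then
+  ///< stays resident; only operand A cycles through the multistage buffer.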
using SmemIteratorB = SmemIteratorB_; + + static cutlass::arch::CacheOperation::Kind const kCacheOpA = CacheOpA; + static cutlass::arch::CacheOperation::Kind const kCacheOpB = CacheOpB; + + static conv::IteratorAlgorithm const kItertorAlgorithm = IteratorAlgorithm_; + + // + // Dependent types + // + + /// Fragment of accumulator tile + + using ElementC = typename Policy::Operator::ElementC; + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Internal structure exposed for introspection. + struct Detail { + + /// Number of cp.async instructions to load one stage of operand A + static int const AsyncCopyIterationsPerStageA = + IteratorA::ThreadMap::Iterations::kCount; + + /// Number of cp.async instructions to load one stage of operand B + static int const AsyncCopyIterationsPerStageB = + IteratorB::ThreadMap::Iterations::kCount; + + /// Number of stages + static int const kStages = Stages; + + /// Number of cp.async instructions to load on group of operand B + static int const kAccessesPerGroupB = + (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + }; + + private: + + using WarpLoadedFragmentA = typename Operator::FragmentA; + using WarpLoadedFragmentB = typename Operator::FragmentB; + using WarpTransformedFragmentA = typename Operator::TransformedFragmentA; + using WarpTransformedFragmentB = typename Operator::TransformedFragmentB; + + private: + + // + // Data members + // + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + DepthwiseFpropDirectConvMultipleStage( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + typename Base::SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) + { + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset( + {warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset( + {Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + CUTLASS_DEVICE + void copy_tiles_and_advance(IteratorA &iterator_A, + IteratorB &iterator_B, + int group_start_A = 0, + int group_start_B = 0) { + if (kItertorAlgorithm == conv::IteratorAlgorithm::kFixedStrideDilation) { + // Number of iterators is a static value. 
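+      // With fixed stride/dilation the activation footprint, and therefore
+      // Detail::AsyncCopyIterationsPerStageA, is known at compile time, so the copy loop below
+      // can be fully unrolled.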
+ iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast(this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + ++this->smem_iterator_A_; + } + } else { + // Number of iterators is a runtime value. + iterator_A.set_iteration_index(group_start_A * IteratorA::kAccessesPerVector); + this->smem_iterator_A_.set_iteration_index(group_start_A); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < iterator_A.get_iteration_num(); ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast(this->smem_iterator_A_.get()); + + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + ++this->smem_iterator_A_; + } + } + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + ///< problem size of GEMM + int gemm_k_iterations, + ///< destination accumulator tile + FragmentC &accum, + ///< iterator over A operand in global memory + IteratorA &iterator_A, + ///< Params of global memory iterator + typename IteratorA::Params const &iterator_a_params, + ///< iterator over B operand in global memory + IteratorB &iterator_B, + ///< Params of global memory iterator + typename IteratorB::Params const &iterator_b_params, + ///< initial value of accumulator + FragmentC const &src_accum, + /// Epilogue + Epilogue &epilogue, + ///< Output operator + typename Epilogue::OutputOp const &output_op, + ///< Tile iterator for destination + typename Epilogue::OutputTileIterator &destination_iterator, + ///< Threadblock tile coordinate in GEMM (in units of threadblock tiles) + typename Epilogue::OutputTileIterator &source_iterator, + + int split_k_slices = 1 + ) { + + // + // Prologue + // + + // Issue several complete stages + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < Base::kStages - 1; ++stage, --gemm_k_iterations) { + + if (stage == 0) { + iterator_B.set_iteration_index(0); + this->smem_iterator_B_.set_iteration_index(0); + + // Async Copy for operand B + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageB; ++j) { + typename IteratorB::AccessType *dst_ptr = + reinterpret_cast(this->smem_iterator_B_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * + IteratorB::ThreadMap::kElementsPerAccess / + IteratorB::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_B.get(), iterator_B.valid()); + + ++iterator_B; + } + + ++this->smem_iterator_B_; + } + } + + if(kItertorAlgorithm == conv::IteratorAlgorithm::kFixedStrideDilation){ + // Number of iterators is compilation static. 
+ iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_index(0); + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < Detail::AsyncCopyIterationsPerStageA; ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast(this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + + } else { + // Number of iterators is a runtime value. + iterator_A.set_iteration_index(0); + this->smem_iterator_A_.set_iteration_num(iterator_A.get_iteration_num()); + this->smem_iterator_A_.set_iteration_index(0); + + + // Async Copy for operand A + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < iterator_A.get_iteration_num(); ++j) { + typename IteratorA::AccessType *dst_ptr = + reinterpret_cast(this->smem_iterator_A_.get()); + + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + int const kSrcBytes = sizeof_bits::value * + IteratorA::ThreadMap::kElementsPerAccess / + IteratorA::kAccessesPerVector / 8; + + cutlass::arch::cp_async_zfill( + dst_ptr + v, iterator_A.get(), iterator_A.valid()); + + ++iterator_A; + } + + ++this->smem_iterator_A_; + } + } + + // Move to the next stage + iterator_A.advance(); + + this->smem_iterator_A_.add_tile_offset({1, 0}); + + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + } + + ///////////////////////////////////////////////////////////////////////////// + // Waits until kStages-2 stages have committed. + cutlass::arch::cp_async_wait(); + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math + // instructions + WarpLoadedFragmentA warp_loaded_frag_A[2]; + WarpLoadedFragmentB warp_loaded_frag_B[2]; + WarpTransformedFragmentA warp_transformed_frag_A[2]; + WarpTransformedFragmentB warp_transformed_frag_B[2]; + + Operator warp_mma; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.setup_initial_status(iterator_a_params); + + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + int smem_write_stage_idx = Base::kStages - 1; + int smem_read_stage_idx = 0; + + warp_mma.transform(warp_transformed_frag_A[0], warp_transformed_frag_B[0], + warp_loaded_frag_A[0], warp_loaded_frag_B[0]); + + // + // Mainloop + // + + unsigned int iterations = 0; + constexpr int inner_loop_iterations = round_up(Base::kWarpGemmIterations, 2); + + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > (-Base::kStages + 1);) { // Each iteration is a cta tile. + + accum.clear(); + + // + // Loop over GEMM K dimension + // + + // Computes a warp-level GEMM on data held in shared memory + // Each "warp_mma_k" refers to a warp-level matrix multiply-accumulate + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < inner_loop_iterations; ++warp_mma_k) { + if (Base::kWarpGemmIterations % 2 == 0 || warp_mma_k + 1 != Base::kWarpGemmIterations) { + // Load warp-level tiles from shared memory, wrapping to k offset if + // this is the last group as the case may be. 
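+        // Warp fragments are double buffered: fragment (warp_mma_k % 2) feeds the current
+        // warp_mma while fragment ((warp_mma_k + 1) % 2) is loaded for the next iteration.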
+ + this->warp_tile_iterator_A_.set_kgroup_index((warp_mma_k + 1) % Shape::kK); + this->warp_tile_iterator_B_.set_kgroup_index((warp_mma_k + 1) % Shape::kK); + + this->warp_tile_iterator_A_.load(warp_loaded_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + } + + if (warp_mma_k > 0) + warp_mma.transform(warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + warp_loaded_frag_A[warp_mma_k % 2], + warp_loaded_frag_B[warp_mma_k % 2]); + + // Issue global->shared copies for the next stage + int group_start_iteration_A, group_start_iteration_B; + + if (warp_mma_k == 0) { + group_start_iteration_A = 0; + group_start_iteration_B = 0; + copy_tiles_and_advance( + iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); + } + + if (warp_mma_k < Base::kWarpGemmIterations) { + warp_mma( + accum, + warp_transformed_frag_A[warp_mma_k % 2], + warp_transformed_frag_B[warp_mma_k % 2], + accum + ); + } + + if (warp_mma_k + 1 == inner_loop_iterations) + warp_mma.transform(warp_transformed_frag_A[(warp_mma_k + 1) % 2], + warp_transformed_frag_B[(warp_mma_k + 1) % 2], + warp_loaded_frag_A[(warp_mma_k + 1) % 2], + warp_loaded_frag_B[(warp_mma_k + 1) % 2]); + + if (warp_mma_k + 2 == inner_loop_iterations) { + // Inserts a fence to group cp.async instructions into stages. + cutlass::arch::cp_async_fence(); + + // Waits until kStages-2 stages of cp.async have committed + arch::cp_async_wait(); + __syncthreads(); + + // Move to the next cta + iterator_A.advance(); + + this->smem_iterator_A_.add_tile_offset({1, 0}); + + // Add negative offsets to return iterators to the 'start' of the + // circular buffer in shared memory + if (smem_write_stage_idx == (Base::kStages - 1)) { + this->smem_iterator_A_.add_tile_offset({-Base::kStages, 0}); + + smem_write_stage_idx = 0; + } else { + ++smem_write_stage_idx; + } + + if (smem_read_stage_idx == (Base::kStages - 1)) { + this->warp_tile_iterator_A_.advance(- (Base::kStages-1) * iterator_A.get_load_size()); + smem_read_stage_idx = 0; + } else { + this->warp_tile_iterator_A_.advance(iterator_A.get_load_size()); + ++smem_read_stage_idx; + } + + if (kItertorAlgorithm == conv::IteratorAlgorithm::kFixedStrideDilation) { + this->warp_tile_iterator_A_.setup_initial_status(iterator_a_params); + } + + // goback to start position. B has no multiple stage + this->warp_tile_iterator_B_.add_tile_offset({-Policy::kPartitionsK * Shape::kK, 0}); + + --gemm_k_iterations; + } + } + + // + // Epilogue + // + int32_t smem_base_offset = iterator_B.get_load_size() + (iterations % Base::kStages) * iterator_A.get_load_size(); + + destination_iterator.set_tile_index(iterations * split_k_slices); + + source_iterator.set_tile_index(iterations * split_k_slices); + + epilogue(output_op, destination_iterator, accum, source_iterator, smem_base_offset); + + ++iterations; + } + + // Insert fence and wait for all outstanding cp.async operations to commit. 
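+    // cp_async_wait<0> leaves no cp.async stage in flight before the kernel exits the mainloop.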
+ cutlass::arch::cp_async_fence(); + cutlass::arch::cp_async_wait<0>(); + __syncthreads(); + + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h b/include/cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h new file mode 100644 index 0000000000..52d604e43c --- /dev/null +++ b/include/cutlass/conv/threadblock/depthwise_fprop_filter_tile_access_iterator_direct_conv_optimized.h @@ -0,0 +1,261 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing loading of convolution tiles mapped to GEMM B (filter tile) + matrix from memory. + + This iterator assumes TensorNHWC layout of tensors in Global Memory. 
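+
+    The threadblock's filter tile (its channel slice across all R*S filter positions) is small,
+    so it is loaded once and kept resident in shared memory; advance() is a no-op.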
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/coord.h" +#include "cutlass/predicate_vector.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/tensor_view.h" +#include "cutlass/layout/pitch_linear.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/layout/matrix.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/conv2d_problem_size.h" +#include "cutlass/conv/threadblock/conv2d_params.h" +#include "cutlass/conv/threadblock/conv2d_fprop_filter_tile_access_iterator_analytic.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +template > +class DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized { +public: + // + // Types + // + + using Shape = Shape_; + using Element = Element_; + using Layout = Layout_; + using ThreadMap = ThreadMap_; + using AccessType = AccessType_; + using TensorRef = cutlass::TensorRef; + using TensorCoord = typename Layout::TensorCoord; + using Index = typename Layout::Index; + using LongIndex = typename Layout::LongIndex; + static IteratorAlgorithm const kIteratorAlgorithm = conv::IteratorAlgorithm::kOptimized; + static StrideSupport const kStrideSupport = conv::StrideSupport::kStrided; + static int const kConvDim = 2; + using ConvProblemSize = typename conv::Conv2dProblemSize; + + static int const kAccessesPerVector = ThreadMap::kElementsPerAccess / AccessType::kElements; + + static int const kFilterSize = ThreadMap::Iterations::kCount * ThreadMap::kElementsPerAccess * ThreadMap::kThreads * + sizeof_bits::value / 8; + + static_assert(!(ThreadMap::kElementsPerAccess % AccessType::kElements), + "Vectors implied by the thread map must be divisible by the access type."); + + // + // Simplifying assertions + // + static_assert(ThreadMap::Iterations::kContiguous == 1, + "Require Iterations::kContiguous == 1"); + + // + // Parameters structure + // + using Params = Depthwise2dFpropDirectConvFilterIteratorParams; + + protected: + + Conv2dProblemSize const &problem_size_; + Params const ¶ms_; + LongIndex iteration_contiguous_; + LongIndex iteration_strided_; + LongIndex iteration_vector_; + char const *pointer_; + + int filter_k_; + int offset_trs_[ThreadMap::Iterations::kStrided]; + +public: + + + + CUTLASS_HOST_DEVICE + DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized( + Params const ¶ms, + Conv2dProblemSize const &problem_size, + Element const *ptr, + int thread_idx, + MatrixCoord const &threadblock_offset = MatrixCoord() + ): + params_(params), + problem_size_(problem_size), + pointer_(reinterpret_cast(ptr)), + filter_k_(0) { + + layout::PitchLinearCoord thread_coord = ThreadMap::initial_offset(thread_idx); + + filter_k_ = threadblock_offset.column() + thread_coord.contiguous(); + + CUTLASS_PRAGMA_UNROLL + for (int s = 0; s < ThreadMap::Iterations::kStrided; ++s) { + offset_trs_[s] = threadblock_offset.row() + thread_coord.strided() + s * ThreadMap::Delta::kStrided; + } + + set_iteration_index(0); + } + + CUTLASS_HOST_DEVICE + static Params getParams(Conv2dProblemSize const &problem_size, Layout const &layout) { + return Params(problem_size, layout, {Shape::kRow, Shape::kColumn}, kFilterSize); + } + + /// Overrides the internal iteration index + CUTLASS_HOST_DEVICE + void set_iteration_index(Index index) { + iteration_vector_ = index % kAccessesPerVector; + int residual_access = index / kAccessesPerVector; + iteration_contiguous_ = residual_access % 
ThreadMap::Iterations::kContiguous; + iteration_strided_ = residual_access / ThreadMap::Iterations::kContiguous; + } + + /// Adds a pointer offset in units of Element + CUTLASS_HOST_DEVICE + void add_pointer_offset(LongIndex pointer_offset) { + pointer_ += pointer_offset * 8 / sizeof_bits::value; + } + + CUTLASS_HOST_DEVICE + void advance() { + // Do nothing because the filter is persistent in the SMEM + } + + /// Returns the coordinate in the filter tensor W that is currently pointed to + /// by the iterator. + CUTLASS_HOST_DEVICE + TensorCoord at() const { + + int k = filter_k_ + iteration_vector_ * AccessType::kElements; + int trs = offset_trs_[iteration_strided_]; + + return TensorCoord(k, trs, 0 , 0); // As a 2D-matrix + } + + /// Returns true if the current coordinate is within the activations tensor W + CUTLASS_HOST_DEVICE + bool valid() const { + + TensorCoord coord = at(); + + return coord.n() < problem_size_.K && + coord.h() < Shape::kColumn; + } + + /// Returns a pointer to the vector starting at the current coordinate + CUTLASS_HOST_DEVICE + AccessType const *get() const { + TensorCoord coord = at(); + int64_t offset = coord.n(); + if (params_.is_convolution) { + offset += (Shape::kColumn - coord.h() - 1)* problem_size_.K; + } else { + offset += coord.h() * problem_size_.K; + } + + return reinterpret_cast(pointer_ + + offset * sizeof_bits::value / 8); + } + + /// Increments to the next memory access + CUTLASS_HOST_DEVICE + DepthwiseFpropFilterDirectConvTileAccessIteratorOptimized &operator++() { + ++iteration_vector_; + if (iteration_vector_ < kAccessesPerVector) { + return *this; + } + iteration_vector_ = 0; + + ++iteration_contiguous_; + if (iteration_contiguous_ < ThreadMap::Iterations::kContiguous) { + return *this; + } + iteration_contiguous_ = 0; + + ++iteration_strided_; + if (iteration_strided_ < ThreadMap::Iterations::kStrided) { + return *this; + } + iteration_strided_ = 0; + + return *this; + } + + /// Determines the filter size loaded by iterator + CUTLASS_HOST_DEVICE + int get_load_size() { + return kFilterSize; + } + + /// Determines whether the Implicit GEMM can execute the given problem. + CUTLASS_HOST_DEVICE + static Status can_implement(Conv2dProblemSize const &problem_size) { + + // check alignment constraint on iterator's contiguous dimension + if (problem_size.K % AccessType::kElements) { + return Status::kErrorInvalidProblem; + } + + // check whether runtime filter size is same as templated filter size. + if ((problem_size.R * problem_size.S) != Shape::kColumn) { + return Status::kErrorInvalidProblem; + } + + return Status::kSuccess; + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/threadblock/depthwise_fprop_pipelined.h b/include/cutlass/conv/threadblock/depthwise_fprop_pipelined.h new file mode 100644 index 0000000000..c2825fa60d --- /dev/null +++ b/include/cutlass/conv/threadblock/depthwise_fprop_pipelined.h @@ -0,0 +1,336 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a double-buffered threadblock-scoped GEMM kernel. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/aligned_buffer.h" +#include "cutlass/numeric_conversion.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/threadblock/mma_base.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. 
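+/// This depthwise variant is double buffered (kStages == 2): global->shared copies for the
+/// next k-block are issued while the current one is consumed by warp-level MMAs.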
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Iterates over tiles of A operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorA_, + /// Iterates over tiles of A operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorA_, + /// Iterates over tiles of B operand in global memory + // (concept: ReadableTileIterator | ForwardTileIterator | MaskedTileIterator) + typename IteratorB_, + /// Iterates over tiles of B operand in shared memory + /// (concept: WriteableTileIterator | RandomAccessTileIterator) + typename SmemIteratorB_, + /// Data type of accumulator matrix + typename ElementC_, + /// Data type of accumulator matrix + typename LayoutC_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Transformation applied to A operand + typename TransformA_ = NumericArrayConverter< + typename SmemIteratorA_::Element, + typename IteratorA_::Element, + IteratorA_::Fragment::kElements>, + /// + /// Transformation applied to A operand + typename TransformB_ = NumericArrayConverter< + typename SmemIteratorB_::Element, + typename IteratorB_::Element, + IteratorB_::Fragment::kElements>, + /// Used for partial specialization + typename Enable = bool +> +class DepthwiseFpropPipelined : public gemm::threadblock::MmaBase { +public: + + ///< Base class + using Base = gemm::threadblock::MmaBase; + + using Shape = Shape_; ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using IteratorA = IteratorA_; ///< Iterates over tiles of A operand in global memory + using IteratorB = IteratorB_; ///< Iterates over tiles of B operand in global memory + using ElementC = ElementC_; ///< Data type of accumulator matrix + using LayoutC = LayoutC_; ///< Layout of accumulator matrix + using Policy = Policy_; ///< Policy describing tuning details + + using SmemIteratorA = SmemIteratorA_; + using SmemIteratorB = SmemIteratorB_; + + using TransformA = TransformA_; + using TransformB = TransformB_; + + // + // Dependent types + // + + /// Fragment of operand A loaded from global memory + using FragmentA = typename IteratorA::Fragment; + + /// Fragment of operand B loaded from global memory + using FragmentB = typename IteratorB::Fragment; + + /// Fragment of accumulator tile + using FragmentC = typename Policy::Operator::FragmentC; + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Obtain the arch tag from the warp-level operator + using ArchTag = typename Policy::Operator::ArchTag; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = Operator::kTransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = Operator::kTransformB; + + // staticaly assert kStages for MmaPipelined is two (Double-buffered pipeline) + static_assert((Base::kStages==2), "MmaPipelined requires kStages set to value 2"); + +private: + + using WarpFragmentA = typename Operator::FragmentA; + using WarpFragmentB = typename Operator::FragmentB; + +protected: + + /// Iterator to write threadblock-scoped tile of A operand to shared memory + SmemIteratorA smem_iterator_A_; + + /// Iterator to write threadblock-scoped tile of B operand to shared memory + SmemIteratorB smem_iterator_B_; + +public: + + /// Construct from tensor references + CUTLASS_DEVICE + DepthwiseFpropPipelined( + typename Base::SharedStorage &shared_storage, ///< Shared storage needed for 
internal use by threadblock-scoped GEMM + int thread_idx, ///< ID within the threadblock + int warp_idx, ///< ID of warp + int lane_idx ///< ID of each thread within a warp + ): + Base(shared_storage, thread_idx, warp_idx, lane_idx), + smem_iterator_A_(shared_storage.operand_A_ref(), thread_idx), + smem_iterator_B_(shared_storage.operand_B_ref(), thread_idx) { + + // Compute warp location within threadblock tile by mapping the warp_id to + // three coordinates: + // _m: the warp's position within the threadblock along the M dimension + // _n: the warp's position within the threadblock along the N dimension + // _k: the warp's position within the threadblock along the K dimension + + int warp_idx_mn = warp_idx % (Base::WarpCount::kM * Base::WarpCount::kN); + int warp_idx_k = warp_idx / (Base::WarpCount::kM * Base::WarpCount::kN); + + int warp_idx_m = warp_idx_mn % Base::WarpCount::kM; + int warp_idx_n = warp_idx_mn / Base::WarpCount::kM; + + // Add per-warp offsets in units of warp-level tiles + this->warp_tile_iterator_A_.add_tile_offset({warp_idx_m, Base::kWarpGemmIterations * warp_idx_k}); + this->warp_tile_iterator_B_.add_tile_offset({Base::kWarpGemmIterations * warp_idx_k, warp_idx_n}); + } + + /// Perform a threadblock-scoped matrix multiply-accumulate + CUTLASS_DEVICE + void operator()( + int gemm_k_iterations, ///< number of iterations of the mainloop + FragmentC &accum, ///< destination accumulator tile + IteratorA iterator_A, ///< iterator over A operand in global memory + IteratorB iterator_B, ///< iterator over B operand in global memory + FragmentC const &src_accum, ///< source accumulator tile + int gemm_k_iterations_per_channel = 0, ///< number of iterations per channel + TransformA transform_A = TransformA(), ///< transformation applied to A fragment + TransformB transform_B = TransformB()) { ///< transformation applied to B fragment + + // + // Prologue + // + + // Perform accumulation in the 'd' output operand + accum = src_accum; + + FragmentA tb_frag_A; + FragmentB tb_frag_B; + + tb_frag_A.clear(); + tb_frag_B.clear(); + + // The last kblock is loaded in the prolog + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + __syncthreads(); + + // Pair of fragments used to overlap shared memory loads and math instructions + WarpFragmentA warp_frag_A[2]; + WarpFragmentB warp_frag_B[2]; + + this->warp_tile_iterator_A_.set_kgroup_index(0); + this->warp_tile_iterator_B_.set_kgroup_index(0); + + this->warp_tile_iterator_A_.load(warp_frag_A[0]); + this->warp_tile_iterator_B_.load(warp_frag_B[0]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + Operator warp_mma; + + int smem_write_stage_idx = 1; + // Depthwise specific + int channel_start_index = 0; + int rs_plane_idx = 0; + + // Issue loads during the first warp-level matrix multiply-add *AFTER* issuing + // shared memory loads (which have the tightest latency requirement). + + // + // Mainloop + // + + // Note: The main loop does not support Base::kWarpGemmIterations == 2. + CUTLASS_GEMM_LOOP + for (; gemm_k_iterations > 0; --gemm_k_iterations) { + // + // Loop over GEMM K dimension + // + + if(rs_plane_idx == gemm_k_iterations_per_channel - 1){ + // Reset interation index. 
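+        // rs_plane_idx counts the mainloop iterations spent on the current channel group's
+        // filter planes; after gemm_k_iterations_per_channel of them the filter iterator
+        // restarts from its first position.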
+ iterator_B.set_iteration_index(0); + } + + CUTLASS_PRAGMA_UNROLL + for (int warp_mma_k = 0; warp_mma_k < Base::kWarpGemmIterations; ++warp_mma_k) { + + // Load warp-level tiles from shared memory, wrapping to k offset if this is the last group + // as the case may be. + + if (warp_mma_k == Base::kWarpGemmIterations - 1) { + + // Write fragments to shared memory + this->smem_iterator_A_.store(transform_A(tb_frag_A)); + + this->smem_iterator_B_.store(transform_B(tb_frag_B)); + + __syncthreads(); + + if(rs_plane_idx == gemm_k_iterations_per_channel - 1){ + // Move to next set of filter groups. + channel_start_index += Base::kWarpGemmIterations; + } + + ++this->smem_iterator_A_; + ++this->smem_iterator_B_; + + // Add negative offsets to return iterators to the 'start' of the circular buffer in shared memory + if (smem_write_stage_idx == 1) { + this->smem_iterator_A_.add_tile_offset({0, -Base::kStages}); + this->smem_iterator_B_.add_tile_offset({-Base::kStages, 0}); + } + else { + this->warp_tile_iterator_A_.add_tile_offset( + {0, -Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations}); + this->warp_tile_iterator_B_.add_tile_offset( + {-Base::kStages * Policy::kPartitionsK * Base::kWarpGemmIterations, + 0}); + } + + smem_write_stage_idx ^= 1; + } + + this->warp_tile_iterator_A_.set_kgroup_index(channel_start_index + (warp_mma_k + 1) % Base::kWarpGemmIterations); + this->warp_tile_iterator_B_.set_kgroup_index(channel_start_index + (warp_mma_k + 1) % Base::kWarpGemmIterations); + + this->warp_tile_iterator_A_.load(warp_frag_A[(warp_mma_k + 1) % 2]); + this->warp_tile_iterator_B_.load(warp_frag_B[(warp_mma_k + 1) % 2]); + + ++this->warp_tile_iterator_A_; + ++this->warp_tile_iterator_B_; + + if (warp_mma_k == 0) { + + iterator_A.load(tb_frag_A); + iterator_B.load(tb_frag_B); + + ++iterator_A; + ++iterator_B; + } + + warp_mma(accum, warp_frag_A[warp_mma_k % 2], + warp_frag_B[warp_mma_k % 2], accum); + } + + rs_plane_idx = (rs_plane_idx == gemm_k_iterations_per_channel - 1) ? 0: (rs_plane_idx + 1); + + } + + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace gemm +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/threadblock/depthwise_mma_base.h b/include/cutlass/conv/threadblock/depthwise_mma_base.h new file mode 100644 index 0000000000..967587be05 --- /dev/null +++ b/include/cutlass/conv/threadblock/depthwise_mma_base.h @@ -0,0 +1,229 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Template for a directconv threadblock-scoped Depthwise kernel. +*/ + +#pragma once + +#include "cutlass/aligned_buffer.h" +#include "cutlass/arch/memory.h" +#include "cutlass/array.h" +#include "cutlass/cutlass.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/numeric_types.h" + +//////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +//////////////////////////////////////////////////////////////////////////////// + +/// Policy object describing MmaTensorOp +template < + /// Warp-level GEMM operator (concept: gemm::warp::Mma) + typename Operator_, + /// Padding used for A operand in shared memory (concept: MatrixShape) + typename SmemPaddingA_, + /// Padding used for B operand in shared memory (concept: MatrixShape) + typename SmemPaddingB_, + /// + typename ThreadMapA_, + /// + typename ThreadMapB_, + /// Number of partitions of K dimension of GEMM + int PartitionsK = 1> +struct DepthwiseDirectConvMmaPolicy { + /// Warp-level GEMM operator (concept: gemm::warp::MmaTensorOp or gemm::warp::MmaSimt) + using Operator = Operator_; + + /// Padding used for A operand in shared memory + using SmemPaddingA = SmemPaddingA_; + + /// Padding used for B operand in shared memory + using SmemPaddingB = SmemPaddingB_; + + using ThreadMapA = ThreadMapA_; + using ThreadMapB = ThreadMapB_; + + /// Number of partitions of K dimension + static int const kPartitionsK = PartitionsK; +}; + +//////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math +/// instructions. +template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Policy describing tuning details (concept: MmaPolicy) + typename Policy_, + /// Number of stages, + int Stages, + /// Used for partial specialization + typename Enable = bool> +class DepthwiseDirectConvMmaBase { + public: + ///< Size of the Gemm problem - concept: gemm::GemmShape<> + using Shape = Shape_; + + ///< Policy describing tuning details + using Policy = Policy_; + + // + // Dependent types + // + + /// Warp-level Mma + using Operator = typename Policy::Operator; + + /// Shape describing the overall GEMM computed from shared memory + /// by each warp. + using WarpGemm = typename Policy::Operator::Shape; + + /// Shape describing the number of warps filling the CTA + using WarpCount = cutlass::gemm:: + GemmShape; + + /// Number of warp-level GEMM oeprations + /// kWarpGemmIterations could be even and odd. 
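  /// Example (illustrative shapes, not taken from a specific kernel): a SIMT warp tile with
  /// WarpGemm::kK == 9 and Operator::Policy::MmaShape::kK == 1 gives kWarpGemmIterations == 9,
  /// an odd count. Odd counts are permitted here, in contrast to the double-buffered fusion
  /// pipelines elsewhere in this change, which static_assert an even inner-loop count.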
+ static int const kWarpGemmIterations = (WarpGemm::kK / Operator::Policy::MmaShape::kK); + + /// Number of stages + static int const kStages = Stages; + + /// Tensor reference to the A operand + using TensorRefA = TensorRef; + + /// Tensor reference to the B operand + using TensorRefB = TensorRef; + + static_assert(kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + // + // Nested structs + // + + /// Shared storage object needed by threadblock-scoped GEMM + class SharedStorage { + public: + // + // Type definitions + // + + /// Shape of the A matrix operand in shared memory + using ShapeA = MatrixShape<1, // Not determined at compile-time :( + Shape::kN + Policy::SmemPaddingA::kRow>; + + /// Shape of the B matrix operand in shared memory + using ShapeB = MatrixShape; // Tile N = 64? + + public: + // + // Data members + // + + // Let persistent B matrix in front of dynamic matrix A + /// Buffer for B operand + AlignedBuffer operand_B; + + /// Buffer for A operand + /// Not be determined at compile-time -- Just to get a Smem start address. + AlignedBuffer operand_A; + public: + // + // Methods + // + + /// Returns a layout object for the A matrix + CUTLASS_DEVICE + static typename Operator::LayoutA LayoutA() { + return Operator::LayoutA::packed({ShapeA::kRow, ShapeA::kColumn}); + } + + /// Returns a layout object for the B matrix + CUTLASS_HOST_DEVICE + static typename Operator::LayoutB LayoutB() { + return Operator::LayoutB::packed({ShapeB::kRow, ShapeB::kColumn}); + } + + /// Returns a TensorRef to the A operand + CUTLASS_HOST_DEVICE + TensorRefA operand_A_ref() { return TensorRefA{operand_A.data(), LayoutA()}; } + + /// Returns a TensorRef to the B operand + CUTLASS_HOST_DEVICE + TensorRefB operand_B_ref() { return TensorRefB{operand_B.data(), LayoutB()}; } + }; + + protected: + // + // Data members + // + + /// Iterator to load a warp-scoped tile of A operand from shared memory + typename Operator::IteratorA warp_tile_iterator_A_; + + /// Iterator to load a warp-scoped tile of B operand from shared memory + typename Operator::IteratorB warp_tile_iterator_B_; + + public: + /// Construct from tensor references + CUTLASS_DEVICE + DepthwiseDirectConvMmaBase( + ///< Shared storage needed for internal use by threadblock-scoped GEMM + SharedStorage &shared_storage, + ///< ID within the threadblock + int thread_idx, + ///< ID of warp + int warp_idx, + ///< ID of each thread within a warp + int lane_idx) + : warp_tile_iterator_A_(shared_storage.operand_A_ref(), lane_idx), + warp_tile_iterator_B_(shared_storage.operand_B_ref(), lane_idx) {} +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace threadblock +} // namespace conv +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h b/include/cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h new file mode 100644 index 0000000000..de84180f38 --- /dev/null +++ b/include/cutlass/conv/threadblock/depthwise_mma_core_with_lane_access_size.h @@ -0,0 +1,952 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Defines basic properties needed by CTA-level GEMMs assuming expectations about data + layout of the global memory fragments, data types, and internal tile sizes. + + Partial specializations for threadblock::Mma operations targeting depthwise related simt instructions. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" + +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/gemm/warp/mma.h" + +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/warp/mma_depthwise_simt.h" + +#include "cutlass/gemm/threadblock/mma_pipelined.h" +#include "cutlass/gemm/threadblock/mma_singlestage.h" + +#include "cutlass/gemm/threadblock/mma_base.h" +#include "cutlass/conv/threadblock/depthwise_mma_base.h" + +#include "cutlass/transform/threadblock/regular_tile_access_iterator_pitch_linear_direct_conv.h" + +#include "cutlass/arch/cache_operation.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace threadblock { + +namespace detail { +// +// Convert a WarpShapeM which is the whole tile of elements into the number of elements (2D) held by +// each partitions within warp. +// The goal is for each thread's tile of elements to be as square as +// possible for performance (4x4 will be faster than 2x8). +template // The number of partitions within the warp +struct SimtWarpShape { + // kP * kQ * WarpNumThreadsM = WarpShapeM + // If needed, enable more specializations. 
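  //
  // Worked examples of the invariant above, using the specializations that follow
  // (kP * kQ * WarpNumThreadsM == WarpShapeM):
  //
  //   SimtWarpShape< 8, 1> -> kP = 2, kQ = 4   (2 * 4 * 1 ==  8)
  //   SimtWarpShape<16, 2> -> kP = 2, kQ = 4   (2 * 4 * 2 == 16)
  //   SimtWarpShape<32, 4> -> kP = 2, kQ = 4   (2 * 4 * 4 == 32)
  //
  // Each thread therefore covers a kP x kQ patch of output pixels, chosen to be as close
  // to square as the factorization allows.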
+}; +template <> +struct SimtWarpShape<4, 4> { + static constexpr int kP = 1; + static constexpr int kQ = 1; +}; + +template <> +struct SimtWarpShape<4, 2> { + static constexpr int kP = 2; + static constexpr int kQ = 1; +}; + +template <> +struct SimtWarpShape<4, 1> { + static constexpr int kP = 2; + static constexpr int kQ = 2; +}; + +template <> +struct SimtWarpShape<8, 1> { + static constexpr int kP = 2; + static constexpr int kQ = 4; +}; +template <> +struct SimtWarpShape<8, 2> { + static constexpr int kP = 2; + static constexpr int kQ = 2; +}; +template <> +struct SimtWarpShape<8, 4> { + static constexpr int kP = 1; + static constexpr int kQ = 2; +}; + +template <> +struct SimtWarpShape<16, 1> { + static constexpr int kP = 4; + static constexpr int kQ = 4; +}; +template <> +struct SimtWarpShape<16, 2> { + static constexpr int kP = 2; + static constexpr int kQ = 4; +}; +template <> +struct SimtWarpShape<16, 4> { + static constexpr int kP = 2; + static constexpr int kQ = 2; +}; + +template +struct SimtWarpShape<25, WarpNumThreadsM> { + static_assert(WarpNumThreadsM == 1, "WarpShapeM could not be evenly splited by threads"); + static constexpr int kP = 5; + static constexpr int kQ = 5; +}; + +template <> +struct SimtWarpShape<32, 1> { + static constexpr int kP = 4; + static constexpr int kQ = 8; +}; + +template <> +struct SimtWarpShape<32, 2> { + static constexpr int kP = 4; + static constexpr int kQ = 4; +}; + +template <> +struct SimtWarpShape<32, 4> { + static constexpr int kP = 2; + static constexpr int kQ = 4; +}; + +} // namespace detail + +template < + /// Shape of threadblock-scoped matrix multiply operator + typename Shape, + /// Shape of warp-level matrix multiply operator + typename WarpShape, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape, + /// Element data type of A operand + typename ElementA, + /// Layout of operand A + typename LayoutA, + /// Element data type of B operand + typename ElementB, + /// Layout of operand B + typename LayoutB, + /// Data type of accumulator + typename ElementC, + /// Layout of accumulator + typename LayoutC, + /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) + typename OperatorClass, + /// Size of a warp-scoped per thread access + int kLaneAccessSizeA_ = 0, + /// Size of a warp-scoped per thread access + int kLaneAccessSizeB_ = 0, + /// Number of stages + int Stages = 2, + /// Operation performed by MMA + typename Operator = typename platform::conditional< + (platform::is_same::value) && + (platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value), + cutlass::arch::OpMultiplyAddSaturate, + cutlass::arch::OpMultiplyAdd>::type, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor = false, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA = + cutlass::arch::CacheOperation::Global, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB = + cutlass::arch::CacheOperation::Global, + /// per-element transformation for elements of A + ComplexTransform TransformA = ComplexTransform::kNone, + /// per-element transformation for elements of B + ComplexTransform TransformB = ComplexTransform::kNone, + bool IsComplex = false // (is_complex::value || is_complex::value) +> +struct DepthwiseMmaCoreWithLaneAccessSize; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Shape of threadblock-scoped matrix multiply operator + typename Shape, + /// Shape of threadblock-scoped output tile + typename ThreadBlockOutputShape, + /// Shape of filter shape per threadblock + typename FilterShape, + /// Shape of warp-level matrix multiply operator + typename WarpShape, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape, + /// Element data type of A operand + typename ElementA, + /// Layout of operand A + typename LayoutA, + /// Element data type of B operand + typename ElementB, + /// Layout of operand B + typename LayoutB, + /// Data type of accumulator + typename ElementC, + /// Layout of accumulator + typename LayoutC, + /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) + typename OperatorClass, + /// Size of a warp-scoped per thread access + int kLaneAccessSizeA_ = 0, + /// Size of a warp-scoped per thread access + int kLaneAccessSizeB_ = 0, + /// Number of stages + int Stages = 2, + /// Operation performed by MMA + typename Operator = typename platform::conditional< + (platform::is_same::value) && + (platform::is_same::value || + platform::is_same::value || + platform::is_same::value || + platform::is_same::value), + cutlass::arch::OpMultiplyAddSaturate, + cutlass::arch::OpMultiplyAdd>::type, + /// Iterator algo type + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, + /// Stride ( MatrixShape ) + typename StrideShape = cutlass::MatrixShape<-1, -1>, + /// Dilation ( MatrixShape ) + typename DilationShape = cutlass::MatrixShape<-1, -1>, + /// Activation Shape loaded by threadblock + typename ActivationShape = cutlass::conv::TensorNHWCShape<-1,-1,-1,-1>, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. 
+ bool AccumulatorsInRowMajor = false, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA = + cutlass::arch::CacheOperation::Global, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB = + cutlass::arch::CacheOperation::Global, + /// per-element transformation for elements of A + ComplexTransform TransformA = ComplexTransform::kNone, + /// per-element transformation for elements of B + ComplexTransform TransformB = ComplexTransform::kNone, + bool IsComplex = false // (is_complex::value || is_complex::value) +> +struct DepthwiseDirectConvMmaCoreWithLaneAccessSize; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Shape of threadblock-scoped matrix multiply operator + typename Shape, + /// Shape of warp-level matrix multiply operator + typename WarpShape, + /// Shape of one matrix production operation (concept: GemmShape) + typename InstructionShape, + /// Element data type of A operand + typename ElementA, + /// Layout of operand A + typename LayoutA, + /// Element data type of B operand + typename ElementB, + /// Layout of operand B + typename LayoutB, + /// Data type of accumulator + typename ElementC, + /// Layout of accumulator + typename LayoutC, + /// Indicates type of math operator (arch::OpClassSimt or arch::OpClassTensorOp) + typename OperatorClass, + /// Number of stages + int Stages, + /// Operation performed by MMA + typename Operator, + /// Store the accumulators in row major or column major. Row major is used + /// when output layout is interleaved. + bool AccumulatorsInRowMajor, + /// Cache operation of operand A + cutlass::arch::CacheOperation::Kind CacheOpA, + /// Cache operation of operand B + cutlass::arch::CacheOperation::Kind CacheOpB, + /// per-element transformation for elements of A + ComplexTransform TransformA, + /// per-element transformation for elements of B + ComplexTransform TransformB, + bool IsComplex +> +struct DepthwiseMmaCoreWithLaneAccessSize< + Shape, WarpShape, InstructionShape, + ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, + OperatorClass, -1, -1, Stages, Operator, AccumulatorsInRowMajor, + CacheOpA, CacheOpB, TransformA, TransformB, IsComplex +> : cutlass::gemm::threadblock::DefaultMmaCore< + Shape, WarpShape, InstructionShape, + ElementA, LayoutA, ElementB, LayoutB, ElementC, LayoutC, + OperatorClass, Stages, Operator, AccumulatorsInRowMajor, + CacheOpA, CacheOpB, TransformA, TransformB, IsComplex +> {}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: row-major +/// B: column-major +/// Operator: simt class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Size of a warp-scoped per thread access (a value of -1 indicates the default) + int kLaneAccessSizeA_, + /// Size of a warp-scoped per thread access (a value of -1 indicates the default) + int kLaneAccessSizeB_, + /// Operation performed by GEMM + typename Operator_> +struct DepthwiseMmaCoreWithLaneAccessSize, + ElementA_, + 
layout::RowMajor, + ElementB_, + layout::ColumnMajor, + ElementC_, + LayoutC_, + arch::OpClassSimt, + kLaneAccessSizeA_, + kLaneAccessSizeB_, + 2, + Operator_> : public cutlass::gemm::threadblock::DefaultMmaCore, + ElementA_, + layout::RowMajor, + ElementB_, + layout::ColumnMajor, + ElementC_, + LayoutC_, + arch::OpClassSimt, + 2, + Operator_> { + using Base = cutlass::gemm::threadblock::DefaultMmaCore, + ElementA_, + layout::RowMajor, + ElementB_, + layout::ColumnMajor, + ElementC_, + LayoutC_, + arch::OpClassSimt, + 2, + Operator_>; + + using Shape = Shape_; + using WarpShape = WarpShape_; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using ElementA = ElementA_; + using LayoutA = layout::RowMajor; + using ElementB = ElementB_; + using LayoutB = layout::ColumnMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using OperatorClass = arch::OpClassSimt; + + static int const kLaneAccessSizeA = kLaneAccessSizeA_; + static int const kLaneAccessSizeB = kLaneAccessSizeB_; + + // Divisility requirements + static_assert( kLaneAccessSizeA > 0 && kLaneAccessSizeB > 0, + "Size of a warp-scoped per thread access should be larger then ZERO" ); + + /// Default Operator + using Operator = Operator_; + + /// Number of warps present + using WarpCount = typename Base::WarpCount; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && + !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." + ); + + /// Number of threads per warp + static int const kWarpSize = cutlass::gemm::warp::WarpSize::value; + + static int const kElementsPerAccess = 1; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::ColumnMajor; + using SmemLayoutB = layout::RowMajor; + + // + // Iterators to write to shared memory are same as base class + // + + // + // Warp-level matrix multiply operator + // + + // Define the warp-level op + static const int WarpNumThreadsM = cutlass::gemm::threadblock::detail::simt_get_warp_threads_m(); + static const int WarpNumThreadsN = kWarpSize / WarpNumThreadsM; + static const int ThreadTileM = WarpShape::kM / WarpNumThreadsM; + static const int ThreadTileN = WarpShape::kN / WarpNumThreadsN; + static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), + "WarpShape must be divisible by ThreadTile shape."); + static const int LaneLayout = ThreadTileM > 4 && ThreadTileN > 4 ? 
2 : 1; + static const int numElementsA = kLaneAccessSizeA / sizeof_bits::value; + static const int numElementsB = kLaneAccessSizeB / sizeof_bits::value; + static const int LaneM = cutlass::const_min(numElementsA, ThreadTileM); + static const int LaneN = cutlass::const_min(numElementsB, ThreadTileN); + + static int const kPaddingM = cutlass::gemm::threadblock::detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); + static int const kPaddingN = cutlass::gemm::threadblock::detail::simt_transpose_padding(kWarpSize, Shape::kK, sizeof_bits::value); + + static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN), + "Padding must be divisible by Lane"); + + // these should have max of thread tile also + using LaneMmaShape = cutlass::gemm::GemmShape< + LaneM, + LaneN, + 1>; + using Policy = cutlass::gemm::warp::MmaSimtPolicy< + cutlass::MatrixShape, // WarpShape + cutlass::layout::RowMajorInterleaved, // LaneLayout + LaneMmaShape + >; + + using MmaWarpSimt = cutlass::conv::warp::MmaDepthwiseSimt< + WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> + ElementA, /// Data type of A elements + SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) + ElementB, /// Data type of B elements + SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) + ElementC, /// Element type of C matrix + LayoutC, /// Layout of C matrix (concept: MatrixLayout) + Policy /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) + >; + + /// Policy used to define MmaPipelined + using MmaPolicy = cutlass::gemm::threadblock::MmaPolicy< + MmaWarpSimt, + MatrixShape, // skew for A matrix to avoid SMEM bank conflicts + MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts + WarpCount::kK + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: row-major +/// B: row-major +/// Operator: simt class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of threadblock-scoped output tile (concept: TensorNHWCShape) + typename ThreadBlockOutputShape_, + /// Shape of filter shape per threadblock + typename FilterShape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Size of a warp-scoped per thread access + int kLaneAccessSizeA_, + /// Number of stages + int Stages_, + /// Operation performed by GEMM + typename Operator_> +struct DepthwiseDirectConvMmaCoreWithLaneAccessSize, + ElementA_, + layout::RowMajor, + ElementB_, + layout::ColumnMajor, + ElementC_, + LayoutC_, + arch::OpClassSimt, + kLaneAccessSizeA_, + 128, + Stages_, + Operator_> { + using Shape = Shape_; + using FilterShape = FilterShape_; + using WarpShape = WarpShape_; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using ElementA = ElementA_; + using LayoutA = layout::RowMajor; + using ElementB = ElementB_; + using LayoutB = layout::ColumnMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using OperatorClass = arch::OpClassSimt; + + static int const kLaneAccessSizeB = 128; + + // Divisility requirements + static_assert( kLaneAccessSizeB > 0, + "Size of a warp-scoped per 
thread access should be larger then ZERO" ); + + /// Default Operator + using Operator = Operator_; + + /// Number of warps present + using WarpCount = cutlass::gemm::GemmShape< + Shape::kM / WarpShape::kM, + Shape::kN / WarpShape::kN, + 1 + >; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && + !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." + ); + + /// Number of threads per warp + static int const kWarpSize = cutlass::gemm::warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + // For Gmem load + static int const kElementsPerAccessA = 128 / sizeof_bits::value; + static int const kElementsPerAccessB = 128 / sizeof_bits::value; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::RowMajor; + using SmemLayoutB = layout::RowMajor; + + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, // Set kStrided = 1 because activation shape is runtime value. + kThreads, + kElementsPerAccessA + >; + + /// ThreadMap of iterator A + using SmemThreadMapA = IteratorThreadMapA; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIteratorDirectConv< + MatrixShape<1, Shape::kN>, // set kRow is 1 because it is a runtime value + ElementA, + SmemLayoutA, + 0, + SmemThreadMapA, // was IteratorThreadMapA + true // Dynamic iterations. + >; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kThreads, + kElementsPerAccessB + >; + + /// Transpose the ThreadMap of iterator B + using SmemThreadMapB = IteratorThreadMapB; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIteratorDirectConv< + MatrixShape, + ElementB, + SmemLayoutB, + 0, + SmemThreadMapB, // was IteratorThreadMapB + false // static iterations. + >; + + // + // Warp-level matrix multiply operator + // + // Groups per threads + // Fp32: 2 groups + // Fp16: 2 groups + static const int GroupsPerThread = sizeof(ElementB) > 1 ? 
2 : 4; + // Define the warp-level op + static const int WarpNumThreadsN = cutlass::const_min(WarpShape::kN / GroupsPerThread, kWarpSize); + static const int WarpNumThreadsM = kWarpSize / WarpNumThreadsN; + + static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), + "WarpShape must be divisible by ThreadTile shape."); + + // Get output P, Q per thread + static const int TileP = cutlass::conv::threadblock::detail::SimtWarpShape::kP; + static const int TileQ = cutlass::conv::threadblock::detail::SimtWarpShape::kQ; + + static const int LaneLayout = 1; + static const int numElementsB = kLaneAccessSizeB / sizeof_bits::value; + static const int LaneN = cutlass::const_min(numElementsB, WarpShape::kN / WarpNumThreadsN); + + // Define the output tile computed by each thread + using ThreadOutputShape = cutlass::conv::TensorNHWCShape<1, TileP, TileQ, LaneN>; + + // Fetch the channel with same access size + static const int LaneM = LaneN; + + // No paddings + static int const kPaddingM = 0; + static int const kPaddingN = 0; + + static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN), + "Padding must be divisible by Lane"); + + // these should have max of thread tile also + using LaneMmaShape = cutlass::gemm::GemmShape< + LaneM, + LaneN, + 1>; + + using Policy = cutlass::gemm::warp::MmaSimtPolicy< + cutlass::MatrixShape, // WarpShape + cutlass::layout::RowMajorInterleaved, // LaneLayout + LaneMmaShape + >; + + using MmaWarpSimt = cutlass::conv::warp::MmaDepthwiseDirectConvSimt< + WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> + FilterShape, /// Shape of filter shape per threadblock - concept: gemm::GemmShape + ThreadOutputShape, /// Size of the output tile computed by thread - concept: conv::TensorNHWCShape<> + ThreadBlockOutputShape_, /// Size of the output tile computed by threadblock - concept: conv::TensorNHWCShape<> + ElementA, /// Data type of A elements + SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) + ElementB, /// Data type of B elements + SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) + ElementC, /// Element type of C matrix + LayoutC, /// Layout of C matrix (concept: MatrixLayout) + Policy /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) + >; + + /// Policy used to define MmaPipelined + using MmaPolicy = cutlass::conv::threadblock::DepthwiseDirectConvMmaPolicy< + MmaWarpSimt, + MatrixShape, // skew for A matrix to avoid SMEM bank conflicts + MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts + IteratorThreadMapA, + IteratorThreadMapB, + WarpCount::kK + >; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Partial specialization: +/// +/// A: row-major +/// B: row-major +/// Operator: simt class +/// +/// This uses the default warp-level operator given tile sizes +template < + /// Shape of threadblock-scoped matrix multiply operator (concept: + /// GemmShape) + typename Shape_, + /// Shape of threadblock-scoped output tile (concept: TensorNHWCShape) + typename ThreadBlockOutputShape_, + /// Shape of filter shape per threadblock + typename FilterShape_, + /// Shape of warp-level matrix multiply operator (concept: GemmShape) + typename WarpShape_, + /// Data type of A operand + typename ElementA_, + /// Data type of B operand + typename ElementB_, + /// Data type of accumulator + typename ElementC_, + /// Layout of accumulator + typename LayoutC_, + /// Size of a warp-scoped per thread access + int 
kLaneAccessSizeA_, + /// Number of stages + int Stages_, + /// Operation performed by GEMM + typename Operator_, + /// Stride ( MatrixShape ) + typename StrideShape_, + /// Dilation ( MatrixShape ) + typename DilationShape_, + /// Activation Shape loaded by threadblock + typename ActivationShape_> +struct DepthwiseDirectConvMmaCoreWithLaneAccessSize, + ElementA_, + layout::RowMajor, + ElementB_, + layout::ColumnMajor, + ElementC_, + LayoutC_, + arch::OpClassSimt, + kLaneAccessSizeA_, + 128, + Stages_, + Operator_, + IteratorAlgorithm::kFixedStrideDilation, + StrideShape_, + DilationShape_, + ActivationShape_> { + using Shape = Shape_; + using FilterShape = FilterShape_; + using WarpShape = WarpShape_; + using InstructionShape = cutlass::gemm::GemmShape<1, 1, 1>; + using ElementA = ElementA_; + using LayoutA = layout::RowMajor; + using ElementB = ElementB_; + using LayoutB = layout::ColumnMajor; + using ElementC = ElementC_; + using LayoutC = LayoutC_; + using OperatorClass = arch::OpClassSimt; + using StrideShape = StrideShape_; + using DilationShape = DilationShape_; + using ThreadBlockOutputShape = ThreadBlockOutputShape_; + using ActivationShape = ActivationShape_; + + static int const kLaneAccessSizeB = 128; + + // Divisility requirements + static_assert( kLaneAccessSizeB > 0, + "Size of a warp-scoped per thread access should be larger then ZERO" ); + + /// Default Operator + using Operator = Operator_; + + /// Number of warps present + using WarpCount = cutlass::gemm::GemmShape< + Shape::kM / WarpShape::kM, + Shape::kN / WarpShape::kN, + 1 + >; + + // Divisility requirements + static_assert( + !(Shape::kM % WarpShape::kM) && + !(Shape::kN % WarpShape::kN), + "Threadblock-scoped GEMM should be divisible by warp-scoped GEMM size." + ); + + /// Number of threads per warp + static int const kWarpSize = cutlass::gemm::warp::WarpSize::value; + + /// Number of threads total + static int const kThreads = WarpCount::kCount * kWarpSize; + + // For Gmem load + static int const kElementsPerAccessA = 128 / sizeof_bits::value; + static int const kElementsPerAccessB = 128 / sizeof_bits::value; + + // + // Shared memory layouts + // + + using SmemLayoutA = layout::RowMajor; + using SmemLayoutB = layout::RowMajor; + + + // + // Iterators to write to shared memory + // + + /// ThreadMap of iterator A + using IteratorThreadMapA = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kThreads, + kElementsPerAccessA + >; + + /// ThreadMap of iterator A + using SmemThreadMapA = IteratorThreadMapA; + + /// Shared memory iterator to A operand + using SmemIteratorA = transform::threadblock::RegularTileAccessIteratorDirectConv< + MatrixShape, + ElementA, + SmemLayoutA, + 0, + SmemThreadMapA, // was IteratorThreadMapA + false // static iterations. + >; + + /// ThreadMap of iterator B + using IteratorThreadMapB = transform::PitchLinearStripminedThreadMap< + layout::PitchLinearShape, + kThreads, + kElementsPerAccessB + >; + + /// Transpose the ThreadMap of iterator B + using SmemThreadMapB = IteratorThreadMapB; + + /// Shared memory iterator to B operand + using SmemIteratorB = transform::threadblock::RegularTileAccessIteratorDirectConv< + MatrixShape, + ElementB, + SmemLayoutB, + 0, + SmemThreadMapB, // was IteratorThreadMapB + false // static iterations. + >; + + // + // Warp-level matrix multiply operator + // + // Groups per threads + // Fp32: 2 groups + // Fp16: 2 groups + static const int GroupsPerThread = sizeof(ElementB) > 1 ? 
2 : 4; + // Define the warp-level op + static const int WarpNumThreadsN = cutlass::const_min(WarpShape::kN / GroupsPerThread, kWarpSize); + static const int WarpNumThreadsM = kWarpSize / WarpNumThreadsN; + + static const int TileP = cutlass::conv::threadblock::detail::SimtWarpShape::kP; + static const int TileQ = cutlass::conv::threadblock::detail::SimtWarpShape::kQ; + + static_assert(!(WarpShape::kM % WarpNumThreadsM) && !(WarpShape::kN % WarpNumThreadsN), + "WarpShape must be divisible by ThreadTile shape."); + + static const int LaneLayout = 1; + static const int numElementsB = kLaneAccessSizeB / sizeof_bits::value; + static const int LaneN = cutlass::const_min(numElementsB, WarpShape::kN / WarpNumThreadsN); + + // Define the output tile computed by each thread + using ThreadOutputShape = cutlass::conv::TensorNHWCShape<1, TileP, TileQ, LaneN>; + + // Fetch the channel with same access size + static const int LaneM = LaneN; + + // No paddings + static int const kPaddingM = 0; + static int const kPaddingN = 0; + + static_assert(!(kPaddingM % LaneM) && !(kPaddingN % LaneN), + "Padding must be divisible by Lane"); + + // these should have max of thread tile also + using LaneMmaShape = cutlass::gemm::GemmShape< + LaneM, + LaneN, + 1>; + + using Policy = cutlass::gemm::warp::MmaSimtPolicy< + cutlass::MatrixShape, // WarpShape + cutlass::layout::RowMajorInterleaved, // LaneLayout + LaneMmaShape + >; + + using MmaWarpSimt = cutlass::conv::warp::MmaDepthwiseDirectConvSimt< + WarpShape, /// Size of the Gemm problem - concept: gemm::GemmShape<> + FilterShape, /// Shape of filter shape per threadblock - concept: gemm::GemmShape + ThreadOutputShape, /// Size of the output tile computed by thread - concept: conv::TensorNHWCShape<> + ThreadBlockOutputShape, /// Size of the output tile computed by threadblock - concept: conv::TensorNHWCShape<> + ElementA, /// Data type of A elements + SmemLayoutA, /// Layout of A matrix (concept: MatrixLayout) + ElementB, /// Data type of B elements + SmemLayoutB, /// Layout of B matrix (concept: MatrixLayout) + ElementC, /// Element type of C matrix + LayoutC, /// Layout of C matrix (concept: MatrixLayout) + Policy, /// Policy describing warp-level MmaSimtOp (concept: MmaSimtOp policy) + IteratorAlgorithm::kFixedStrideDilation, /// Iterator algo type + StrideShape, /// Stride ( MatrixShape ) + DilationShape, /// Dilation ( MatrixShape ) + ActivationShape /// Activation Shape loaded by threadblock + >; + + /// Policy used to define MmaPipelined + using MmaPolicy = cutlass::conv::threadblock::DepthwiseDirectConvMmaPolicy< + MmaWarpSimt, + MatrixShape, // skew for A matrix to avoid SMEM bank conflicts + MatrixShape<0, kPaddingN>, // skew for B matrix to avoid SMEM bank conflicts + IteratorThreadMapA, + IteratorThreadMapB, + WarpCount::kK + >; +}; +} // namespace threadblock +} // namespace conv +} // namespace cutlass diff --git a/include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h b/include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h index 4f16b42d8c..3bee07d0ab 100644 --- a/include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h +++ b/include/cutlass/conv/threadblock/implicit_gemm_fprop_fusion_multistage.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -64,7 +64,7 @@ #include "cutlass/arch/cache_operation.h" #include "cutlass/gemm/gemm.h" -#include "cutlass/conv/warp/conv2d_fprop_scale_bias_iterator.h" +#include "cutlass/gemm/warp/scale_bias_tile_iterator.h" #include "cutlass/conv/warp/scale_bias_relu_transform.h" ///////////////////////////////////////////////////////////////////////////////////////////////// @@ -139,6 +139,13 @@ class MmaFpropFusionBase { /// Tensor reference to the B operand using TensorRefB = TensorRef; + static_assert(kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + static_assert((kWarpGemmIterations % 2) == 0, + "Inner loop iteration must be an even number."); + // // Nested structs // @@ -319,7 +326,7 @@ class ImplicitGemmFpropFusionMultistage using Policy = Policy_; ///< Base class using Base = MmaFpropFusionBase; using SmemIteratorA = SmemIteratorA_; @@ -518,6 +525,8 @@ class ImplicitGemmFpropFusionMultistage IteratorScaleBias iterator_A_scale_bias, ///< initial value of accumulator FragmentC const &src_accum, + ///< number of iterations per channel + int gemm_k_iterations_per_channel = 0, ///< Imaginary strides used for planar-complex only - ignored here int64_t imag_stride_A = 0, int64_t imag_stride_B = 0) { diff --git a/include/cutlass/conv/threadblock/implicit_gemm_multistage.h b/include/cutlass/conv/threadblock/implicit_gemm_multistage.h index 36b41aacd4..eea7743a40 100644 --- a/include/cutlass/conv/threadblock/implicit_gemm_multistage.h +++ b/include/cutlass/conv/threadblock/implicit_gemm_multistage.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -116,10 +116,6 @@ class ImplicitGemmMultistage : /// Internal structure exposed for introspection. struct Detail { - static_assert(Base::kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - /// Number of cp.async instructions to load one stage of operand A static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; @@ -138,6 +134,12 @@ class ImplicitGemmMultistage : /// Number of cp.async instructions to load on group of operand B static int const kAccessesPerGroupB = (AsyncCopyIterationsPerStageB + Base::kWarpGemmIterations - 1) / Base::kWarpGemmIterations; + + // Optional staged-accumulation (e.g., tf32x3 kernels) for improved numerical + // accuracy, where each mainloop iteration first accumulates into a temporary + // set of freshly-cleared accumulators, which are subsequently added to the + // final accumulator set. 
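    // Condensed sketch of how this flag is consumed later in the mainloop (paraphrased from
    // the surrounding hunks, not a verbatim excerpt):
    //
    //   FragmentC tmp_accum;
    //   if (Detail::kStagedAccumulation) { tmp_accum.clear(); }        // fresh partial sums
    //   ... warp_mma(tmp_accum, frag_A, frag_B, tmp_accum); ...        // accumulate partials
    //   if (Detail::kStagedAccumulation) { accum = plus_accum(accum, tmp_accum); }  // fold in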
+ static bool const kStagedAccumulation = arch::detail::UseStagedAccumulation::value; }; private: @@ -272,6 +274,8 @@ class ImplicitGemmMultistage : IteratorB iterator_B, ///< initial value of accumulator FragmentC const &src_accum, + ///< number of iterations per channel + int gemm_k_iterations_per_channel = 0, ///< Imaginary strides used for planar-complex only - ignored here int64_t imag_stride_A = 0, int64_t imag_stride_B = 0) { @@ -297,7 +301,7 @@ class ImplicitGemmMultistage : CUTLASS_PRAGMA_UNROLL for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { - int const kSrcBytes = + int const kSrcBytes = sizeof_bits::value * IteratorA::ThreadMap::kElementsPerAccess / IteratorA::kAccessesPerVector / 8; @@ -322,7 +326,7 @@ class ImplicitGemmMultistage : this->smem_iterator_B_.get()); CUTLASS_PRAGMA_UNROLL - for (int v = 0; v < IteratorA::kAccessesPerVector; ++v) { + for (int v = 0; v < IteratorB::kAccessesPerVector; ++v) { int const kSrcBytes = sizeof_bits::value * IteratorB::ThreadMap::kElementsPerAccess / @@ -389,10 +393,7 @@ class ImplicitGemmMultistage : FragmentC tmp_accum; - if (platform::is_same::value - || platform::is_same::value) { + if (Detail::kStagedAccumulation) { tmp_accum.clear(); } @@ -446,10 +447,7 @@ class ImplicitGemmMultistage : copy_tiles_and_advance(iterator_A, iterator_B, group_start_iteration_A, group_start_iteration_B); - if (platform::is_same::value - || platform::is_same::value) { + if (Detail::kStagedAccumulation) { warp_mma( tmp_accum, warp_transformed_frag_A[warp_mma_k % 2], @@ -520,10 +518,7 @@ class ImplicitGemmMultistage : } - if (platform::is_same::value - || platform::is_same::value) { + if (Detail::kStagedAccumulation) { accum = plus_accum(accum, tmp_accum); } diff --git a/include/cutlass/conv/threadblock/implicit_gemm_pipelined.h b/include/cutlass/conv/threadblock/implicit_gemm_pipelined.h index f77e2e3325..79bcb78aa5 100644 --- a/include/cutlass/conv/threadblock/implicit_gemm_pipelined.h +++ b/include/cutlass/conv/threadblock/implicit_gemm_pipelined.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -188,6 +188,7 @@ class ImplicitGemmPipelined : public gemm::threadblock::MmaBase; + static_assert(kWarpGemmIterations > 1, + "The pipelined structure requires at least two warp-level " + "GEMM operations."); + + static_assert((kWarpGemmIterations % 2) == 0, + "Inner loop iteration must be an even number."); + // // Nested structs // @@ -306,10 +313,6 @@ class ImplicitGemmWgradFusionMultistage /// Internal structure exposed for introspection. 
struct Detail { - static_assert(Base::kWarpGemmIterations > 1, - "The pipelined structure requires at least two warp-level " - "GEMM operations."); - /// Number of cp.async instructions to load one stage of operand A static int const AsyncCopyIterationsPerStageA = IteratorA::ThreadMap::Iterations::kCount; @@ -470,6 +473,8 @@ class ImplicitGemmWgradFusionMultistage IteratorScaleBias iterator_B_scale_bias, ///< initial value of accumulator FragmentC const &src_accum, + ///< number of iterations per channel + int gemm_k_iterations_per_channel = 0, ///< Imaginary strides used for planar-complex only - ignored here int64_t imag_stride_A = 0, int64_t imag_stride_B = 0) { diff --git a/include/cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h b/include/cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h index 7d60e4b0b2..bfe9a39816 100644 --- a/include/cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h +++ b/include/cutlass/conv/threadblock/predicated_scale_bias_vector_access_iterator.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -113,12 +113,9 @@ class PredicatedScaleBiasVectorAccessIterator( + const_cast(scale_pointer)) + : reinterpret_cast( + const_cast(bias_pointer)); + + // Per-thread offset in logical coordinates of tensor + int thread_base = (thread_id < kThreads) ? 0 : kThreads; + + thread_offset_ = + threadblock_offset + + TensorCoord((thread_id - thread_base) * kElementsPerAccess, 0); + + set_iteration_index(0); + } + + CUTLASS_HOST_DEVICE + PredicatedScaleBiasVectorAccessIterator( + /// Precomputed parameters object + Params const ¶ms, + /// Extent of tensor + Conv3dProblemSize const &problem_size, + /// Pointer to the start of the scale vector + ConstPointer scale_pointer, + /// Pointer to the start of the bias vector + ConstPointer bias_pointer, + /// ID of each participating thread + int thread_id, + /// Initial offset of threadblock + TensorCoord const &threadblock_offset) + : params_(params), + problem_size_trs(problem_size.T * problem_size.R * problem_size.S), + problem_size_c(problem_size.C), + filter_trs_(0) { pointer_ = (thread_id < kThreads) ? 
reinterpret_cast( const_cast(scale_pointer)) @@ -177,6 +207,22 @@ class PredicatedScaleBiasVectorAccessIterator filter, i.e., stride={2x2} and filter={1x1}) // // * Optimization * - // Only launch CTAs in M dimenstion which contribute to a row in Dx output + // Only launch CTAs in M dimension which contribute to a row in Dx output // // // * Constraints * @@ -95,11 +95,11 @@ struct StridedDgradHorizontalThreadblockSwizzle : /// Returns the shape of the problem in units of logical tiles /// For ImplicitGemmConvolution Conv2d problem size: conv_operator(NPQK, NHWC, KRSC) CUTLASS_HOST_DEVICE - gemm::GemmCoord get_tiled_shape( + static gemm::GemmCoord get_tiled_shape( cutlass::conv::Operator conv_operator, cutlass::conv::Conv2dProblemSize const &problem_size, gemm::GemmCoord tile_size, - int split_k_slices) const { + int split_k_slices) { gemm::GemmCoord implicit_gemm_problem_size = cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); @@ -107,7 +107,7 @@ struct StridedDgradHorizontalThreadblockSwizzle : // compute number of tiles in m dimension int tile_m = get_strided_dgrad_tile_m(problem_size, tile_size.m()); - // compute number of tiles in n dimenstion + // compute number of tiles in n dimension int tile_n = (implicit_gemm_problem_size.n() + tile_size.n() - 1) / tile_size.n(); return gemm::GemmCoord( @@ -136,11 +136,11 @@ struct StridedDgradIdentityThreadblockSwizzle : /// Returns the shape of the problem in units of logical tiles /// For ImplicitGemmConvolution Conv2d problem size: conv_operator(NPQK, NHWC, KRSC) CUTLASS_HOST_DEVICE - gemm::GemmCoord get_tiled_shape( + static gemm::GemmCoord get_tiled_shape( cutlass::conv::Operator conv_operator, cutlass::conv::Conv2dProblemSize const &problem_size, gemm::GemmCoord tile_size, - int split_k_slices) const { + int split_k_slices) { gemm::GemmCoord implicit_gemm_problem_size = cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); @@ -148,7 +148,7 @@ struct StridedDgradIdentityThreadblockSwizzle : // compute number of tiles in m dimension int tile_m = get_strided_dgrad_tile_m(problem_size, tile_size.m()); - // compute number of tiles in n dimenstion + // compute number of tiles in n dimension int tile_n = (implicit_gemm_problem_size.n() + tile_size.n() - 1) / tile_size.n(); return gemm::GemmCoord( @@ -157,7 +157,6 @@ struct StridedDgradIdentityThreadblockSwizzle : split_k_slices); } - /// Returns the shape of the problem in units of logical tiles /// For GEMM problem size (MxNxK) (Do not use base class get_tiled_shape()) private: @@ -166,7 +165,29 @@ struct StridedDgradIdentityThreadblockSwizzle : ///////////////////////////////////////////////////////////////////////////////////////////////// +/// Threadblock swizzling function for GEMMs +template +struct DepthwiseDirect2dConvIdentityThreadblockSwizzle + : public gemm::threadblock::GemmIdentityThreadblockSwizzle { + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvIdentityThreadblockSwizzle() {} + + /// Returns the shape of the problem in units of logical tiles + CUTLASS_HOST_DEVICE + static gemm::GemmCoord get_tiled_shape(cutlass::conv::Operator conv_operator, + cutlass::conv::Conv2dProblemSize const &problem_size, + gemm::GemmCoord tile_size, + int split_k_slices) { + + gemm::GemmCoord implicit_gemm_problem_size = + cutlass::conv::implicit_gemm_problem_size(conv_operator, problem_size); + + return gemm::GemmCoord(1, + (implicit_gemm_problem_size.n() + tile_size.n() - 1) / tile_size.n(), + split_k_slices); + } +}; } // namespace threadblock -} // namespace gemm 
+} // namespace conv } // namespace cutlass diff --git a/include/cutlass/conv/warp/mma_depthwise_simt.h b/include/cutlass/conv/warp/mma_depthwise_simt.h new file mode 100644 index 0000000000..ed385df039 --- /dev/null +++ b/include/cutlass/conv/warp/mma_depthwise_simt.h @@ -0,0 +1,380 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Templates implementing warp-level matrix multiply-accumulate operations. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/numeric_types.h" +#include "cutlass/matrix_shape.h" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma.h" + +#include "cutlass/gemm/thread/mma.h" +#include "cutlass/conv/convolution.h" +#include "cutlass/conv/thread/depthwise_mma.h" + + +#include "cutlass/gemm/warp/mma_simt_tile_iterator.h" +#include "cutlass/gemm/warp/mma_simt_policy.h" + +#include "cutlass/gemm/warp/mma_simt.h" +#include "cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. 
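// Usage sketch: the threadblock-level DepthwiseMmaCoreWithLaneAccessSize defined earlier in
// this patch instantiates this class roughly as follows. The concrete shapes and element
// types below are illustrative assumptions, not values taken from the patch:
//
//   using WarpShape    = cutlass::gemm::GemmShape<16, 64, 8>;
//   using LaneMmaShape = cutlass::gemm::GemmShape<4, 4, 1>;
//   using Policy       = cutlass::gemm::warp::MmaSimtPolicy<
//       cutlass::MatrixShape<4, 8>,               // WarpNumThreadsM x WarpNumThreadsN
//       cutlass::layout::RowMajorInterleaved<1>,  // lane layout
//       LaneMmaShape>;
//   using WarpMma = cutlass::conv::warp::MmaDepthwiseSimt<
//       WarpShape,
//       float, cutlass::layout::ColumnMajor,      // A operand held in shared memory
//       float, cutlass::layout::RowMajor,         // B operand held in shared memory
//       float, cutlass::layout::RowMajor,         // accumulator
//       Policy>;
//
// Relative to gemm::warp::MmaSimt, the principal difference is that IteratorB (and its
// fragment types) is replaced so that operand B is traversed with the depthwise tile
// iterator declared in mma_depthwise_simt_tile_iterator.h.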
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Shape of the warp in units of thread (concept: MmaSimtPolicy) + typename Policy_, + /// Number of partitions along K dimension + int PartitionsK = 1, + /// Complex transformation on operand A + ComplexTransform TransformA = ComplexTransform::kNone, + /// Complex transformation on operand B + ComplexTransform TransformB = ComplexTransform::kNone, + /// Used for partial specialization + typename Enable = bool> +class MmaDepthwiseSimt + : public cutlass::gemm::warp:: + MmaSimt { + using Base = cutlass::gemm::warp:: + MmaSimt; + +public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = ElementB_; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassSimt; + + /// Hard-coded for now + using ArchTag = arch::Sm50; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = TransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = TransformB; + +public: + + /// Iterates over the B operand in memory + using IteratorB = cutlass::conv::warp::DepthwiseMmaSimtTileIterator< + MatrixShape, + cutlass::gemm::Operand::kB, + ElementB, + LayoutB, + Policy, + PartitionsK, + Shape::kK + >; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentB = FragmentB; + +public: + + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaDepthwiseSimt():Base() {} +}; + +/// Structure to compute the matrix product targeting CUDA cores and SIMT math instructions. 
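// Unlike MmaDepthwiseSimt above, this direct-convolution variant does not derive from
// gemm::warp::MmaSimt. It composes a thread-level
// cutlass::conv::thread::DepthwiseDirectConvElementwiseInnerProduct operator and is
// additionally parameterized on the per-threadblock filter shape, the per-thread and
// per-threadblock output tile shapes, and, when IteratorAlgorithm::kFixedStrideDilation is
// selected, compile-time stride, dilation, and activation shapes.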
+template < + /// Size of the Gemm problem - concept: gemm::GemmShape<> + typename Shape_, + /// Shape of filter shape per threadblock - concept: gemm::GemmShape + typename FilterShape_, + /// Shape of the output tile computed by thread- concept: conv::TensorNHWCShape<> + typename ThreadOutputShape_, + /// Shape of the output tile computed by threadblock - concept: conv::TensorNHWCShape<> + typename ThreadBlockOutputShape_, + /// Data type of A elements + typename ElementA_, + /// Layout of A matrix (concept: MatrixLayout) + typename LayoutA_, + /// Data type of B elements + typename ElementB_, + /// Layout of B matrix (concept: MatrixLayout) + typename LayoutB_, + /// Element type of C matrix + typename ElementC_, + /// Layout of C matrix (concept: MatrixLayout) + typename LayoutC_, + /// Shape of the warp in units of thread (concept: MmaSimtPolicy) + typename Policy_, + /// Iterator algo type + conv::IteratorAlgorithm IteratorAlgorithm_ = IteratorAlgorithm::kAnalytic, + /// Stride ( MatrixShape ) + typename StrideShape_ = cutlass::MatrixShape<-1, -1>, + /// Dilation ( MatrixShape ) + typename DilationShape_ = cutlass::MatrixShape<-1, -1>, + /// Activation Shape loaded by threadblock + typename ActivationShape_ = cutlass::conv::TensorNHWCShape<-1,-1,-1,-1>, + /// Number of partitions along K dimension + int PartitionsK = 1, + /// Complex transformation on operand A + ComplexTransform TransformA = ComplexTransform::kNone, + /// Complex transformation on operand B + ComplexTransform TransformB = ComplexTransform::kNone, + /// Used for partial specialization + typename Enable = bool> +class MmaDepthwiseDirectConvSimt { + public: + /// Shape of warp-level matrix operation (concept: GemmShape) + using Shape = Shape_; + + /// Shape of filter shape per threadblock - concept: gemm::GemmShape + using FilterShape = FilterShape_; + + /// Shape of the output tile computed by thread- concept: conv::TensorNHWCShape<> + using ThreadOutputShape = ThreadOutputShape_; + + /// Shape of the output tile computed by threadblock - concept: conv::TensorNHWCShape<> + using ThreadBlockOutputShape = ThreadBlockOutputShape_; + + /// Data type of multiplicand A + using ElementA = ElementA_; + + /// Layout of multiplicand A + using LayoutA = LayoutA_; + + /// Data type of multiplicand B + using ElementB = ElementB_; + + /// Layout of multiplicand B + using LayoutB = LayoutB_; + + /// Data type of accumulator matrix C + using ElementC = ElementC_; + + /// Layout of accumulator matrix C + using LayoutC = LayoutC_; + + /// Shape of the warp in units of thread (concept: MmaLanePolicySimt) + using Policy = Policy_; + + /// Iterator algo type + static conv::IteratorAlgorithm const IteratorAlgorithm = IteratorAlgorithm_; + + /// Stride ( MatrixShape ) + using StrideShape = StrideShape_; + + /// Dilation ( MatrixShape ) + using DilationShape = DilationShape_; + + /// Activation Shape loaded by threadblock + using ActivationShape = ActivationShape_; + + /// Indicates class of matrix operator + using OperatorClass = arch::OpClassSimt; + + /// Hard-coded for now + using ArchTag = arch::Sm50; + + /// Complex transform on A operand + static ComplexTransform const kTransformA = TransformA; + + /// Complex transform on B operand + static ComplexTransform const kTransformB = TransformB; + + static constexpr bool use_dp4a = (platform::is_same< layout::ColumnMajorInterleaved<4>, LayoutA>::value || + platform::is_same< layout::RowMajorInterleaved<4>, LayoutA >::value) && + platform::is_same< ElementA, int8_t >::value && + 
platform::is_same< ElementB, int8_t >::value; + + using dp4a_type = typename platform::conditional< use_dp4a , int8_t, bool >::type; + + /// Thread-level matrix multiply accumulate operator + using ThreadMma = cutlass::conv::thread::DepthwiseDirectConvElementwiseInnerProduct< + cutlass::gemm::GemmShape< + Shape::kM / Policy::WarpShape::kRow, // number of output pixels proccessed per thread + Shape::kN / Policy::WarpShape::kColumn, // number of channels proccessed per thread + 1>, + ElementA, + ElementB, + ElementC, + arch::OpMultiplyAdd, + dp4a_type + >; + + /// Underlying matrix multiply operator (concept: arch::Mma) + using ArchMmaOperator = typename ThreadMma::ArchMmaOperator; + + /// Indicates math operator + using MathOperator = typename ArchMmaOperator::Operator; + + /// Shape of the underlying instruction + using InstructionShape = cutlass::gemm::GemmShape<1,1,use_dp4a ? 4 : 1>; + +public: + + /// Iterates over the A operand in memory + using IteratorA = cutlass::conv::warp::DepthwiseDirect2dConvSimtTileIterator< + MatrixShape, // per warp + FilterShape, + ThreadOutputShape, + ThreadBlockOutputShape, + cutlass::gemm::Operand::kA, + ElementA, + Policy, + IteratorAlgorithm, + StrideShape, + DilationShape, + ActivationShape, + PartitionsK, + Shape::kK + >; + + /// Storage for A tile + using FragmentA = typename IteratorA::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentA = FragmentA; + + /// Iterates over the B operand in memory + using IteratorB = cutlass::gemm::warp::MmaSimtTileIterator< + MatrixShape<1, Shape::kN>, + cutlass::gemm::Operand::kB, + ElementB, + LayoutB, + Policy, + PartitionsK, + Shape::kK + >; + + /// Storage for B tile + using FragmentB = typename IteratorB::Fragment; + + /// Storage for transformed A tile + using TransformedFragmentB = FragmentB; + + /// Iterates over the C operand in memory + using IteratorC = cutlass::gemm::warp::MmaSimtTileIterator< + MatrixShape, + cutlass::gemm::Operand::kC, + ElementC, + LayoutC, + Policy + >; + + /// Storage for C tile + using FragmentC = typename ThreadMma::FragmentC; + +public: + + // + // Methods + // + + /// Ctor + CUTLASS_DEVICE + MmaDepthwiseDirectConvSimt() {} + + /// Performs a warp-level matrix multiply-accumulate operation + CUTLASS_DEVICE + void operator()( + FragmentC &d, + FragmentA a, + FragmentB b, + FragmentC const &c, int group_idx = 0) const { + + ThreadMma mma; + + mma(d, a, b, c); + } + + /// Transform the mma operands to the required types + CUTLASS_DEVICE + void transform(TransformedFragmentA &dst_A, TransformedFragmentB &dst_B, + FragmentA const &A, FragmentB const &B) const { + dst_A = A; + dst_B = B; + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace warp +} // namespace conv +} // namespace cutlass diff --git a/include/cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h b/include/cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h new file mode 100644 index 0000000000..26d9638bab --- /dev/null +++ b/include/cutlass/conv/warp/mma_depthwise_simt_tile_iterator.h @@ -0,0 +1,862 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. 
Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Describes the lane policy used by warp-level matrix multiply operators targeting SIMT + instructions +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/tensor_ref.h" +#include "cutlass/matrix_shape.h" + +#include "cutlass/conv/convolution.h" + +#include "cutlass/arch/memory_sm75.h" + +#include "cutlass/layout/matrix.h" + +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/warp/mma_simt_policy.h" +#include "cutlass/gemm/warp/mma_simt_tile_iterator.h" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace conv { +namespace warp { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Iterates over operands to warp-level matrix multiply operations targeting SIMT instructions +/// +/// concept: MutableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Operand identity + cutlass::gemm::Operand Operand, + /// Data type of A elements + typename Element_, + /// Layout of operand + typename Layout_, + /// Shape of the warp in units of thread (concept: MmaSimtPolicy) + typename Policy_, + /// Number of partitions along K dimension - used in sliced-K + int PartitionsK = 1, + /// Group Size along kPartition - used in sliced-K + int PartitionGroupSize = 1 +> +class DepthwiseMmaSimtTileIterator; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Specialization for B operands of row-major layouts +/// +/// Concept: MutableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Data type of A elements + typename Element_, + /// Shape of the warp in units of thread (concept: MmaSimtPolicy) + typename Policy_, + /// Number of partitions along K dimension + int PartitionsK, + /// Group Size along kPartition - used in sliced-K + int 
PartitionGroupSize> +class DepthwiseMmaSimtTileIterator + : public cutlass::gemm::warp::MmaSimtTileIterator { + + using Base = cutlass::gemm::warp::MmaSimtTileIterator; + public: + /// Shape of tile to load (concept: MatrixShape) + using Shape = Shape_; + + /// Operand tag + static cutlass::gemm::Operand const kOperand = cutlass::gemm::Operand::kB; + + /// Element type + using Element = Element_; + + /// Layout of policy + using Layout = layout::RowMajor; + + /// Decomposition of elements among threads + using Policy = Policy_; + + /// TensorRef type for loading element from a tensor + using TensorRef = typename Base::TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + /// Thread-level shape of a fragment + using ThreadShape = typename Base::ThreadShape; + + /// Number of individual loads + using Iterations = typename Base::Iterations; + + /// Fragment object holding a thread's part of a tile + using Fragment = typename Base::Fragment; + + static_assert(Policy::LaneMmaShape::kN == 1, "Each thread should be 1 element per LDS along the k-dim"); + +private: + + MatrixCoord lane_offset_; + int channel_idx_; + int base_channel_idx_; + int warps_n_; + + public: + + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + DepthwiseMmaSimtTileIterator():Base() { } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + DepthwiseMmaSimtTileIterator( + TensorRef ref, + int lane_id + ) : Base(ref, lane_id) { + + // compute offset based on thread ID and lane layout + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + warps_n_ = -1; + channel_idx_ = 0; + base_channel_idx_ = 0; + lane_offset_ = lane_layout.inverse(lane_id) * MatrixCoord(0, Policy::LaneMmaShape::kN); + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + DepthwiseMmaSimtTileIterator &add_tile_offset(TensorCoord const &coord) { + + if(warps_n_ == -1){ + warps_n_ = coord.column(); + } + + Base::add_tile_offset(coord); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. (vector loads) + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { + Array *dst_ptr = + reinterpret_cast *>(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < Iterations::kRow; ++k) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Iterations::kColumn; ++n) { + + void const *ptr = this->ref_.data() + + this->ref_.offset({-(channel_idx_ - base_channel_idx_), + n * Policy::WarpShape::kColumn}) + + pointer_offset / Policy::LaneMmaShape::kN; + + // Base_k of a warp + Base_k of current threads. + int thread_k_base_idx = + warps_n_ * Shape::kColumn / Policy::LaneMmaShape::kN + lane_offset_.column(); + + if (channel_idx_ + k == thread_k_base_idx + n * Policy::WarpShape::kColumn) { + // Depthwise kernel would only do computation when channel == k. + // Loads an element when the current computation channel == the k corresponding to this thread. + arch::shared_load(dst_ptr[n + k * Iterations::kColumn], ptr); + } else { + // Reduce SMEM load + dst_ptr[n + k * Iterations::kColumn].fill(Element(0)); + } + } + } + } + + /// Loads a fragment from memory at the location pointed to by the iterator. 
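// Note on load_with_pointer_offset() above: a fragment slot (k, n) is populated from
// shared memory only when channel_idx_ + k matches the k index owned by this thread,
// i.e. thread_k_base_idx + n * Policy::WarpShape::kColumn; every other slot is
// zero-filled. In a depthwise convolution each output channel depends on exactly one
// filter channel, so skipping the non-matching taps reduces shared-memory traffic at
// the cost of a few register fills.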
+ CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + load_with_pointer_offset(frag, 0); + } + + /// Notify the iterator which k-group it is currently pointing to. + /// + /// This does not advance the iterator. Rather, it overrides its internal + /// tracking with constant-valued k-group index + CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + if(k_group % PartitionGroupSize == 0 && k_group != 0){ + base_channel_idx_ = k_group; + } + channel_idx_ = k_group; + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// + +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Size of filter (concept: gemm::GemmShape) + typename FilterShape_, + /// Size of the matrix to load (concept: MatrixShape) + typename ThreadOutputShape_, + /// Size of the matrix to load (concept: MatrixShape) + typename ThreadBlockOutputShape_, + /// Operand identity + cutlass::gemm::Operand Operand, + /// Data type of A elements + typename Element_, + /// Shape of the warp in units of thread (concept: MmaSimtPolicy) + typename Policy_, + /// Iterator algo type + conv::IteratorAlgorithm IteratorAlgorithm = IteratorAlgorithm::kAnalytic, + /// Stride ( MatrixShape ) + typename StrideShape = cutlass::MatrixShape<-1, -1>, + /// Dilation ( MatrixShape ) + typename DilationShape = cutlass::MatrixShape<-1, -1>, + /// Activation Shape loaded by threadblock + typename ActivationShape = cutlass::conv::TensorNHWCShape<-1,-1,-1,-1>, + /// Number of partitions along K dimension - used in sliced-K + int PartitionsK = 1, + /// Group Size along kPartition - used in sliced-K + int PartitionGroupSize = 1> +class DepthwiseDirect2dConvSimtTileIterator; + + +/// Specialization for A operands of row-major layouts +/// +/// Concept: MutableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Size of filter (concept: gemm::GemmShape) + typename FilterShape_, + /// Size of the matrix to load (concept: TensorNHWC) + typename ThreadOutputShape_, + /// Size of the matrix to load (concept: TensorNHWC) + typename ThreadBlockOutputShape_, + /// Data type of A elements + typename Element_, + /// Shape of the warp in units of thread (concept: MmaSimtPolicy) + typename Policy_, + /// Iterator algo type + conv::IteratorAlgorithm IteratorAlgorithm, + /// Stride ( MatrixShape ) + typename StrideShape, + /// Dilation ( MatrixShape ) + typename DilationShape, + /// Activation Shape loaded by threadblock + typename ActivationShape, + /// Number of partitions along K dimension - used in sliced-K + int PartitionsK, + /// Group Size along kPartition - used in sliced-K + int PartitionGroupSize> +class DepthwiseDirect2dConvSimtTileIterator { + public: + /// Shape of tile to load (concept: MatrixShape) + using Shape = Shape_; + + /// Shape of filter (concept: gemm::GemmShape) + using FilterShape = FilterShape_; + + /// Shape of tile to load (concept: TensorNHWC) + using ThreadOutputShape = ThreadOutputShape_; + + /// Shape of tile to load (concept: TensorNHWC) + using ThreadBlockOutputShape = ThreadBlockOutputShape_; + + /// Operand tag + static cutlass::gemm::Operand const kOperand = cutlass::gemm::Operand::kA; + + /// Element type + using Element = Element_; + + /// Layout of policy + using Layout = layout::RowMajor; + + /// Decomposition of elements among threads + using Policy = Policy_; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + 
+ /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + // + // Derived quantities + // + + static_assert(!(Shape::kRow % Policy::WarpShape::kRow), + "The warp-level GEMM M size must be divisible by the number of threads arranged along the M dimension."); + + static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); + static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); + static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero."); + static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, "Shape::kRow / Policy::WarpShape::kRow must be greater than zero."); + +// Thread-level shape of a fragment + using ThreadShape = MatrixShape< + ThreadOutputShape::kNHW, // Output tile shape Computed by current threads + ThreadOutputShape::kC + >; + + static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN), + "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); + + /// Number of individual loads + using Iterations = MatrixShape< + ThreadShape::kRow, + ThreadShape::kColumn / Policy::LaneMmaShape::kN + >; + + using ThreadTileCount = MatrixShape< + ThreadBlockOutputShape::kH / ThreadOutputShape::kH, + ThreadBlockOutputShape::kW / ThreadOutputShape::kW + >; + + /// Fragment object holding a thread's part of a tile + using Fragment = Array; + +protected: + + /// Internal reference + cutlass::TensorRef, layout::RowMajor> ref_; + + int activation_offset[ThreadOutputShape::kH][ThreadOutputShape::kW][Iterations::kColumn]; + int iterator_r_; + int iterator_s_; + int iterator_offset_; + + int inc_next_s_ ; + int inc_next_r_ ; + + MatrixCoord lane_offset_; +public: + + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator() { } + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator( + TensorRef ref, + int lane_id + ) { + + // compute offset based on thread ID and lane layout + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + // Set channel offset + lane_offset_ = lane_layout.inverse(lane_id) * MatrixCoord(0, Policy::LaneMmaShape::kN); + + ref.add_coord_offset(lane_offset_); + + ref_.reset(reinterpret_cast *>(ref.data()), + ref.stride(0) / Policy::LaneMmaShape::kN); + + iterator_r_ = 0; + iterator_s_ = 0; + iterator_offset_ = 0; + } + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator &add_pointer_offset(LongIndex offset) { + ref_.add_pointer_offset(offset); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. 
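// Worked example for the derived shapes above (hypothetical tile sizes, not taken from
// this header): with ThreadOutputShape = TensorNHWCShape<1, 2, 2, 8>,
// ThreadBlockOutputShape = TensorNHWCShape<1, 8, 8, 64>, and Policy::LaneMmaShape::kN == 4:
//   ThreadShape     = MatrixShape<1*2*2, 8> = 4 x 8 outputs per thread,
//   Iterations      = MatrixShape<4, 8/4>   = 4 x 2 vector loads per tile,
//   ThreadTileCount = MatrixShape<8/2, 8/2> = 4 x 4 thread groups tiling the
//                     threadblock output; setup_initial_status() below uses it to derive
//                     each thread group's base (p, q) output coordinate.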
+ template + CUTLASS_HOST_DEVICE + void setup_initial_status(Params const& params) { + + inc_next_s_ = params.inc_next[0]; + inc_next_r_ = params.inc_next[1]; + + // Get base HW offset of current threads + int threadgroup = threadIdx.x / (ThreadBlockOutputShape::kC / ThreadOutputShape::kC); + int base_p_ = + (threadgroup / (ThreadTileCount::kColumn)) * ThreadOutputShape::kH; + int base_q_ = + (threadgroup % (ThreadTileCount::kColumn)) * ThreadOutputShape::kW; + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < ThreadOutputShape::kH; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int q = 0; q < ThreadOutputShape::kW; ++q) { + CUTLASS_PRAGMA_UNROLL + for (int col = 0; col < Iterations::kColumn; ++col) { + int base_w = (base_q_ + q) * params.stride[0]; + int base_h = (base_p_ + p) * params.stride[1]; + + int offset = base_h * params.activation_tile_w + base_w; + activation_offset[p][q][col] = offset; + } + } + } + } + + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator &add_tile_offset(TensorCoord const &coord) { + // Set warp row and col start + lane_offset_ = MatrixCoord({lane_offset_.row() + coord.row() * Shape::kRow, lane_offset_.column()}); + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + void advance(int32_t pointer_offset) { + ref_.reset(ref_.data() + pointer_offset / sizeof(Element) / Policy::LaneMmaShape::kN); + iterator_s_ = 0; + iterator_r_ = 0; + iterator_offset_ = 0; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator &operator++() { + ++iterator_s_; + if (iterator_s_ < FilterShape::kColumn) { + iterator_offset_ += inc_next_s_; + + return *this; + } + + iterator_s_ = 0; + + ++iterator_r_; + if (iterator_r_ < FilterShape::kRow) { + iterator_offset_ += inc_next_r_; + return *this; + } + + iterator_r_ = 0; + iterator_offset_ = 0; + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator & operator--() { + // Do nothing + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. (vector loads) + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { + + Array *dst_ptr = + reinterpret_cast *>(&frag); + + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < ThreadOutputShape::kH; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int q = 0; q < ThreadOutputShape::kW; ++q) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Iterations::kColumn; ++n) { + void const *ptr = ref_.data() + + ref_.offset({activation_offset[p][q][n] + (iterator_offset_), + n * Policy::WarpShape::kColumn}) + + pointer_offset / Policy::LaneMmaShape::kN; + arch::shared_load(dst_ptr[n + q + p * ThreadOutputShape::kW], ptr); + } + } + } + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { + load_with_pointer_offset(frag, 0); + } + + /// Stores a fragment to memory at the location pointed to by the iterator + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { + // Do nothing at present. 
+ } + + /// Stores a fragment to memory at the location pointed to by the iterator + CUTLASS_HOST_DEVICE + void store(Fragment const &frag, Index pointer_offset) const { + store_with_pointer_offset(frag, 0); + } + + CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + // no operation here + } +}; + +/////////////////////////////////////////////////////////////////////////////////////////////////// +/// Specialization for A operands of row-major layouts +/// +/// Concept: MutableRandomAccessContiguousTileIteratorConcept +/// +template < + /// Size of the matrix to load (concept: MatrixShape) + typename Shape_, + /// Size of filter (concept: gemm::GemmShape) + typename FilterShape_, + /// Size of the matrix to load (concept: TensorNHWC) + typename ThreadOutputShape_, + /// Size of the matrix to load (concept: TensorNHWC) + typename ThreadBlockOutputShape_, + /// Data type of A elements + typename Element_, + /// Shape of the warp in units of thread (concept: MmaSimtPolicy) + typename Policy_, + /// Stride ( MatrixShape ) + typename StrideShape_, + /// Dilation ( MatrixShape ) + typename DilationShape_, + /// Activation Shape loaded by threadblock + typename ActivationShape_, + /// Number of partitions along K dimension - used in sliced-K + int PartitionsK, + /// Group Size along kPartition - used in sliced-K + int PartitionGroupSize> +class DepthwiseDirect2dConvSimtTileIterator { + public: + /// Shape of tile to load (concept: MatrixShape) + using Shape = Shape_; + + /// Shape of filter (concept: gemm::GemmShape) + using FilterShape = FilterShape_; + + /// Shape of tile to load (concept: TensorNHWC) + using ThreadOutputShape = ThreadOutputShape_; + + /// Shape of tile to load (concept: TensorNHWC) + using ThreadBlockOutputShape = ThreadBlockOutputShape_; + + /// Stride ( MatrixShape ) + using StrideShape = StrideShape_; + + /// Dilation ( MatrixShape ) + using DilationShape = DilationShape_; + + /// Activation Shape loaded by threadblock + using ActivationShape = ActivationShape_; + + /// Operand tag + static cutlass::gemm::Operand const kOperand = cutlass::gemm::Operand::kA; + + /// Element type + using Element = Element_; + + /// Layout of policy + using Layout = layout::RowMajor; + + /// Decomposition of elements among threads + using Policy = Policy_; + + /// TensorRef type for loading element from a tensor + using TensorRef = TensorRef; + + /// Index type + using Index = typename TensorRef::Index; + + /// Long Index type + using LongIndex = typename TensorRef::LongIndex; + + /// Coordinate for an element in the tensor + using TensorCoord = typename TensorRef::TensorCoord; + + // + // Derived quantities + // + + static_assert(!(Shape::kRow % Policy::WarpShape::kRow), + "The warp-level GEMM M size must be divisible by the number of threads arranged " + "along the M dimension."); + + static_assert(Shape::kRow > 0, "Shape::kRow must be greater than zero."); + static_assert(Shape::kColumn > 0, "Shape::kColumn must be greater than zero."); + static_assert(Policy::WarpShape::kRow > 0, "Policy::WarpShape::kRow must be greater than zero."); + static_assert(Shape::kRow / Policy::WarpShape::kRow > 0, + "Shape::kRow / Policy::WarpShape::kRow must be greater than zero."); + + // Activations loaded by threadblock + static int const ThreadActivationShapeH = (ThreadOutputShape::kH - 1) * StrideShape::kRow + + (FilterShape::kRow - 1) * DilationShape::kRow + 1; + + static int const ThreadActivationShapeW = (ThreadOutputShape::kW - 1) * StrideShape::kColumn + + (FilterShape::kColumn - 1) * 
DilationShape::kColumn + 1; + + using ThreadActivationShape = cutlass::conv:: + TensorNHWCShape<1, ThreadActivationShapeH, ThreadActivationShapeW, ThreadOutputShape::kC>; + + // Thread-level shape of a fragment + using ThreadShape = + MatrixShape; + + static_assert(!(ThreadShape::kColumn % Policy::LaneMmaShape::kN), + "Thread-level GEMM must be divisible by Policy::LaneMmaShape."); + + /// Number of individual loads + using Iterations = + MatrixShape; + + using ThreadTileCount = MatrixShape; + + /// Fragment object holding a thread's part of a tile + using Fragment = Array; + + protected: + /// Internal reference + cutlass::TensorRef, layout::RowMajor> ref_; + + Array + activation[ThreadActivationShape::kH][ThreadActivationShape::kW][Iterations::kColumn]; + int iterator_r_; + int iterator_s_; + + + MatrixCoord lane_offset_; + + public: + /// Default ctor constructs null iterator + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator() {} + + /// Constructor from TensorRef + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator(TensorRef ref, int lane_id) { + // compute offset based on thread ID and lane layout + typename Policy::LaneLayout lane_layout = Policy::get_lane_layout(); + + // Set channel offset + lane_offset_ = lane_layout.inverse(lane_id) * MatrixCoord(0, Policy::LaneMmaShape::kN); + + ref.add_coord_offset(lane_offset_); + + ref_.reset(reinterpret_cast *>(ref.data()), + ref.stride(0) / Policy::LaneMmaShape::kN); + + iterator_r_ = 0; + iterator_s_ = 0; + } + + /// Adds a pointer offset to internal pointer(s) to advance through memory + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator &add_pointer_offset(LongIndex offset) { + ref_.add_pointer_offset(offset); + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. 
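// Worked example for ThreadActivationShape above (hypothetical sizes, not taken from this
// header): with ThreadOutputShape::kH == ThreadOutputShape::kW == 2, FilterShape 3x3,
// StrideShape 1x1, and DilationShape 1x1,
//   ThreadActivationShapeH = (2 - 1) * 1 + (3 - 1) * 1 + 1 = 4,
//   ThreadActivationShapeW = (2 - 1) * 1 + (3 - 1) * 1 + 1 = 4,
// so setup_initial_status() below stages a 4 x 4 x kC activation patch per thread, and
// load_with_pointer_offset() indexes into it with
//   h = p * StrideShape::kRow    + iterator_r_ * DilationShape::kRow,
//   w = q * StrideShape::kColumn + iterator_s_ * DilationShape::kColumn.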
+ template + CUTLASS_HOST_DEVICE void setup_initial_status( + Params const ¶ms) { + + // Get base HW offset of current threads + int threadgroup = threadIdx.x / (ThreadBlockOutputShape::kC / ThreadOutputShape::kC); + int base_h = + (threadgroup / (ThreadTileCount::kColumn)) * ThreadOutputShape::kH * StrideShape::kRow; + int base_w = + (threadgroup % (ThreadTileCount::kColumn)) * ThreadOutputShape::kW * StrideShape::kColumn; + + CUTLASS_PRAGMA_UNROLL + for (int h = 0; h < ThreadActivationShape::kH; ++h) { + CUTLASS_PRAGMA_UNROLL + for (int w = 0; w < ThreadActivationShape::kW; ++w) { + CUTLASS_PRAGMA_UNROLL + for (int col = 0; col < Iterations::kColumn; ++col) { + int offset = (base_h + h) * ActivationShape::kW + (base_w + w); + + void const *ptr = ref_.data() + ref_.offset({offset, col * Policy::WarpShape::kColumn}); + arch::shared_load(activation[h][w][col], ptr); + } + } + } + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator &add_tile_offset(TensorCoord const &coord) { + // Set warp row and col start + lane_offset_ = + MatrixCoord({lane_offset_.row() + coord.row() * Shape::kRow, lane_offset_.column()}); + return *this; + } + + /// Advances an iterator along logical dimensions of matrix in units of whole tiles + CUTLASS_HOST_DEVICE + void advance(int32_t pointer_offset) { + ref_.reset(ref_.data() + pointer_offset / sizeof(Element) / Policy::LaneMmaShape::kN); + iterator_s_ = 0; + iterator_r_ = 0; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator &operator++() { + ++iterator_s_; + if (iterator_s_ < FilterShape::kColumn) { + return *this; + } + + iterator_s_ = 0; + + ++iterator_r_; + if (iterator_r_ < FilterShape::kRow) { + return *this; + } + + iterator_r_ = 0; + return *this; + } + + /// Advances the iterator along the advance dimension + CUTLASS_HOST_DEVICE + DepthwiseDirect2dConvSimtTileIterator &operator--() { + // Do nothing + return *this; + } + + /// Loads a fragment from memory at the location pointed to by the iterator. (vector loads) + CUTLASS_HOST_DEVICE + void load_with_pointer_offset(Fragment &frag, Index pointer_offset) const { + Array *dst_ptr = + reinterpret_cast *>(&frag); + + CUTLASS_PRAGMA_UNROLL + for (int p = 0; p < ThreadOutputShape::kH; ++p) { + CUTLASS_PRAGMA_UNROLL + for (int q = 0; q < ThreadOutputShape::kW; ++q) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < Iterations::kColumn; ++n) { + const int h = p * StrideShape::kRow + iterator_r_ * DilationShape::kRow; + const int w = q * StrideShape::kColumn + iterator_s_ * DilationShape::kColumn; + + dst_ptr[n + q + p * ThreadOutputShape::kW] = activation[h][w][n]; + } + } + } + } + + /// Loads a fragment from memory at the location pointed to by the iterator. + CUTLASS_HOST_DEVICE + void load(Fragment &frag) const { load_with_pointer_offset(frag, 0); } + + /// Stores a fragment to memory at the location pointed to by the iterator + CUTLASS_HOST_DEVICE + void store_with_pointer_offset(Fragment const &frag, Index pointer_offset) const { + // Do nothing at present. 
+ } + + /// Stores a fragment to memory at the location pointed to by the iterator + CUTLASS_HOST_DEVICE + void store(Fragment const &frag, Index pointer_offset) const { + store_with_pointer_offset(frag, 0); + } + + CUTLASS_DEVICE + void set_kgroup_index(int k_group) { + // no operation here + } +}; + +} // namespace warp +} // namespace conv +} // namespace cutlass diff --git a/include/cutlass/conv/warp/scale_bias_relu_transform.h b/include/cutlass/conv/warp/scale_bias_relu_transform.h index 5bcbfcd0bd..4da31ab818 100644 --- a/include/cutlass/conv/warp/scale_bias_relu_transform.h +++ b/include/cutlass/conv/warp/scale_bias_relu_transform.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -101,9 +101,8 @@ struct FpropScaleBiasReluTransform { "}\n" : "=r"(ptr_activations[0]) : "r"(ptr_scale_bias[0]), "r"(ptr_activations[0]), - "r"(ptr_scale_bias[1]), "n"(0x7eff7eff)); + "r"(ptr_scale_bias[1]), "n"(cutlass::arch::OOB_NAN_F16x2)); #else - // TODO: write emulation code assert(0); #endif } @@ -151,8 +150,8 @@ struct WgradScaleBiasReluTransform { #if 1 // CUDA + PTX version - bool h1_oob = (reinterpret_cast(ptr_activations[0].x) == 0x7eff); - bool h2_oob = (reinterpret_cast(ptr_activations[0].y) == 0x7eff); + bool h1_oob = (reinterpret_cast(ptr_activations[0].x) == cutlass::arch::OOB_NAN_F16); + bool h2_oob = (reinterpret_cast(ptr_activations[0].y) == cutlass::arch::OOB_NAN_F16); // Apply per channel scale+bias+relu if the data is not a special NaN // (0x7eff). If it is a special NaN (0x7eff), hard code the output to 0. @@ -161,7 +160,7 @@ struct WgradScaleBiasReluTransform { // out-of-bound because C x R x S can be an odd number. asm volatile( "{\n\t" - " fma.rn.f16x2.relu %0 , %1, %2, %3;\n" + " fma.rn.f16x2.relu %0, %1, %2, %3;\n" "}" : "=r"(reinterpret_cast(ptr_activations[0])) : "r"(ptr_scale_bias[0]), "r"(reinterpret_cast(ptr_activations[0])), @@ -195,10 +194,9 @@ struct WgradScaleBiasReluTransform { "}\n" : "=r"(reinterpret_cast(ptr_activations[0])) : "r"(ptr_scale_bias[0]), "r"(reinterpret_cast(ptr_activations[0])), - "r"(ptr_scale_bias[1]), "n"(0x7eff), "n"(0xffff0000), "n"(0x0000ffff)); + "r"(ptr_scale_bias[1]), "n"(cutlass::arch::OOB_NAN_F16), "n"(0xffff0000), "n"(0x0000ffff)); #endif #else - // TODO: write emulation code assert(0); #endif } diff --git a/include/cutlass/coord.h b/include/cutlass/coord.h index 1fe8ec0feb..fe884d7037 100644 --- a/include/cutlass/coord.h +++ b/include/cutlass/coord.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -37,7 +37,7 @@ #if defined(__CUDACC_RTC__) #include #else -#include +#include #endif #include "cutlass/cutlass.h" @@ -100,12 +100,21 @@ struct Coord { } } + /// Constructs from some other Coord + template + CUTLASS_HOST_DEVICE + Coord(Coord other) { + for (int i = 0; i < kRank; ++i) { + idx[i] = other[i]; + } + } + /// Returns a slice of the Coord which may be larger or smaller in rank /// than this. template CUTLASS_HOST_DEVICE - Coord slice(int start = 0, Index identity = 0) const { - Coord result; + Coord slice(int start = 0, Index identity = 0) const { + Coord result; for (int i = 0; i < Slice; ++i) { if (i + start < kRank) { result[i] = idx[i + start]; diff --git a/include/cutlass/core_io.h b/include/cutlass/core_io.h index eef4360267..40ae22246a 100644 --- a/include/cutlass/core_io.h +++ b/include/cutlass/core_io.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -31,7 +31,6 @@ /*! \file \brief Helpers for printing cutlass/core objects */ - #pragma once #include @@ -45,7 +44,7 @@ #include "cutlass/matrix_shape.h" #include "cutlass/layout/pitch_linear.h" #include "cutlass/tensor_view.h" -#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/gemm_enumerated_types.h" #include "cutlass/conv/convolution.h" #include "cutlass/conv/conv2d_problem_size.h" #include "cutlass/conv/conv3d_problem_size.h" @@ -252,8 +251,9 @@ namespace conv { inline std::ostream& operator<<(std::ostream& out, Conv2dProblemSize const& problem) { out << "NHWC: (" << problem.N << ", " << problem.H << ", " << problem.W << ", " << problem.C << ")" << std::endl - << "KRSC: (" << problem.K << ", " << problem.R << ", " << problem.S << ", " << problem.C << ")" << std::endl + << "KRSC: (" << problem.K << ", " << problem.R << ", " << problem.S << ", " << problem.C / problem.groups << ")" << std::endl << "NPQK: (" << problem.N << ", " << problem.P << ", " << problem.Q << ", " << problem.K << ")" << std::endl + << "groups: (" << problem.groups << ")" << std::endl << "Pad_h, Pad_w: (" << problem.pad_h << ", " << problem.pad_w << ")" << std::endl << "Stride_h, Stride_w: (" << problem.stride_h << ", " << problem.stride_w << ")" << std::endl << "Dilation_h, Dilation_w: (" << problem.dilation_h << ", " << problem.dilation_w << ")" << std::endl diff --git a/include/cutlass/cuda_host_adapter.hpp b/include/cutlass/cuda_host_adapter.hpp new file mode 100644 index 0000000000..2adfd2665f --- /dev/null +++ b/include/cutlass/cuda_host_adapter.hpp @@ -0,0 +1,412 @@ +/*************************************************************************************************** + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Interface betweeen a CUTLASS device-wide operator and CUDA. +*/ + +#pragma once + +#include +#include "cutlass/cutlass.h" +#include "cutlass/trace.h" + +#include "cutlass/platform/platform.h" +#if ! defined(__CUDACC_RTC__) +#include +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// NVRTC doesn't need definitions for these host classes + +#if ((__CUDACC_VER_MAJOR__ >= 12) || \ + ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))) \ + && !defined(__CUDACC_RTC__) +#define CUDA_HOST_ADAPTER_LAUNCH_ATTRIBUTES_ENABLED +#endif + +#if ((__CUDACC_VER_MAJOR__ >= 12) && !defined(__CUDACC_RTC__)) +#define CUDA_HOST_ADAPTER_TENSORMAP_ENABLED +#endif + +// Include for CUDA Driver API calls if any of these capabilities are enabled. +#if defined(CUDA_HOST_ADAPTER_LAUNCH_ATTRIBUTES_ENABLED) || \ + defined(CUDA_HOST_ADAPTER_TENSORMAP_ENABLED) + +#include + +#endif // defined(CUDA_HOST_ADAPTER_LAUNCH_ATTRIBUTES_ENABLED) || + // defined(CUDA_HOST_ADAPTER_TENSORMAP_ENABLED) + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// +// Macro-level guard for CUDA Host Adapter +// +#if !defined(CUTLASS_ENABLE_CUDA_HOST_ADAPTER) +#define CUTLASS_ENABLE_CUDA_HOST_ADAPTER false +#endif + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +#if !defined(__CUDACC_RTC__) + +#if ((__CUDACC_VER_MAJOR__ >= 12) || \ + ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 8))) +#include +#endif // (__CUDACC_VERSION__ >= 11.8) + +#include + +#define CUTLASS_CUDA_DRIVER_STRINGIFY(tok) #tok + +#if defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL) + +#define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver) \ + template \ + CUresult call_##func(Args... 
args) { \ + return func(args...); \ + } + +#else // defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL) + +#if ((__CUDACC_VER_MAJOR__ >= 13) || \ + ((__CUDACC_VER_MAJOR__ == 12) && (__CUDACC_VER_MINOR__ >= 5))) \ + +#define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver) \ + template \ + CUresult call_##func(Args... args) { \ + cudaDriverEntryPointQueryResult cuda_status; \ + void* pfn = nullptr; \ + cudaError_t cuda_err = cudaGetDriverEntryPointByVersion( \ + CUTLASS_CUDA_DRIVER_STRINGIFY(func), \ + &pfn, ver, \ + cudaEnableDefault, \ + &cuda_status); \ + if (cuda_status != cudaDriverEntryPointSuccess || \ + cuda_err != cudaSuccess) { \ + return CUDA_ERROR_UNKNOWN; \ + } \ + return reinterpret_cast(pfn)(args...); \ + } + +#else + +#define CUTLASS_CUDA_DRIVER_WRAPPER_DECL(func, ver) \ + template \ + CUresult call_##func(Args... args) { \ + cudaDriverEntryPointQueryResult cuda_status; \ + void* pfn = nullptr; \ + cudaError_t cuda_err = cudaGetDriverEntryPoint( \ + CUTLASS_CUDA_DRIVER_STRINGIFY(func), \ + &pfn, \ + cudaEnableDefault, \ + &cuda_status); \ + if (cuda_status != cudaDriverEntryPointSuccess || \ + cuda_err != cudaSuccess) { \ + return CUDA_ERROR_UNKNOWN; \ + } \ + return reinterpret_cast(pfn)(args...); \ + } + +#endif // (__CUDACC_VERSION__ >= 12.5) + +#endif // defined(CUTLASS_ENABLE_DIRECT_CUDA_DRIVER_CALL) + +#if (__CUDACC_VER_MAJOR__ >= 12) +CUTLASS_CUDA_DRIVER_WRAPPER_DECL(cuTensorMapEncodeTiled, 12000); +CUTLASS_CUDA_DRIVER_WRAPPER_DECL(cuTensorMapEncodeIm2col, 12000); +#endif + +#undef CUTLASS_CUDA_DRIVER_STRINGIFY + +#define CUTLASS_CUDA_DRIVER_WRAPPER_CALL(func) cutlass::call_##func + +#endif // !defined(__CUDACC_RTC__) + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// This class manages runtime CUlaunchAttribute that can be supplied to CudaHostAdapter +/// CudaHostLaunchAttributes will be an empty struct in earlier CTK where CUlaunchAttribute +/// is not introduced. +struct CudaHostLaunchAttributes { + +#if defined(CUDA_HOST_ADAPTER_LAUNCH_ATTRIBUTES_ENABLED) + + /// Reasonable maximum launch attributes that are commonly applied + static constexpr int32_t kMaximumAttributeCount = 5; + + /// Launch attributes + CUlaunchAttribute launch_attributes[kMaximumAttributeCount]; + int32_t attribute_count = 0; + + CUTLASS_HOST_DEVICE + CudaHostLaunchAttributes(CUlaunchAttribute *launch_attributes_ = nullptr, + int32_t attribute_count_ = 0) { + CUTLASS_ASSERT(attribute_count_ >= 0 && attribute_count_ < kMaximumAttributeCount); + for (int32_t i = 0; i < attribute_count_ && i < kMaximumAttributeCount; ++i) { + launch_attributes[i] = launch_attributes_[i]; + } + attribute_count = attribute_count_; + } + + CUTLASS_HOST_DEVICE + CUlaunchAttribute const* data() const { + return launch_attributes; + } + + CUTLASS_HOST_DEVICE + size_t size() const { + return attribute_count; + } + +#endif // (CUDA_HOST_ADAPTER_LAUNCH_ATTRIBUTES_ENABLED) + +}; + + +/// This class defines an object which abstracts interactions between the CUTLASS device-wide GEMM and +/// CUDA. The intention is to enable CUTLASS to be used with both the CUDA Runtime API and CUDA Driver API. 
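// Usage sketch for CudaHostLaunchAttributes above (hypothetical attribute choice; assumes
// CUDA 11.8+ so that CUDA_HOST_ADAPTER_LAUNCH_ATTRIBUTES_ENABLED is defined):
//
//   CUlaunchAttribute pdl_attr{};
//   pdl_attr.id = CU_LAUNCH_ATTRIBUTE_PROGRAMMATIC_STREAM_SERIALIZATION;
//   pdl_attr.value.programmaticStreamSerializationAllowed = 1;
//   cutlass::CudaHostLaunchAttributes launch_attrs(&pdl_attr, 1);
//
// A concrete CudaHostAdapter subclass (declared below) can then forward launch_attrs to
// the extensible launch API (cuLaunchKernelEx / cudaLaunchKernelExC) when it implements
// the pure-virtual launch() methods.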
+struct CudaHostAdapter { + + /// Limit the number of kernels + static constexpr int32_t kMaximumKernelCount = 4; + + /// Maximum cluster size + static constexpr int MaxClusterSize = 32; + + // + // Data members + // + + /// Handles + void *kernel_handles[kMaximumKernelCount]; + int32_t kernel_count = 0; + + CudaHostLaunchAttributes launch_attributes; + + // + // Methods + // + + /// Ctor + CudaHostAdapter() = default; + + /// Dtor + virtual ~CudaHostAdapter() = default; + + /// Copy Ctor + CUTLASS_HOST_DEVICE + CudaHostAdapter(const CudaHostAdapter & rhs) + : kernel_count(rhs.kernel_count), + launch_attributes(rhs.launch_attributes) { + CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount); + + for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) { + kernel_handles[i] = rhs.kernel_handles[i]; + } + } + + /// Copy Assignment + CUTLASS_HOST_DEVICE + CudaHostAdapter& operator=(const CudaHostAdapter & rhs) { + CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount); + for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) { + kernel_handles[i] = rhs.kernel_handles[i]; + } + kernel_count = rhs.kernel_count; + + launch_attributes = rhs.launch_attributes; + + return *this; + } + + + /// Move ctor + CUTLASS_HOST_DEVICE + CudaHostAdapter(CudaHostAdapter && rhs) + : kernel_count(rhs.kernel_count), + launch_attributes(std::move(rhs.launch_attributes)) { + CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount); + + for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) { + kernel_handles[i] = rhs.kernel_handles[i]; + } + } + + // / Move assignment + CUTLASS_HOST_DEVICE + CudaHostAdapter& operator=(CudaHostAdapter && rhs) { + CUTLASS_ASSERT(rhs.kernel_count >= 0 && rhs.kernel_count < kMaximumKernelCount); + for (int32_t i = 0; i < rhs.kernel_count && i < kMaximumKernelCount; ++i) { + kernel_handles[i] = rhs.kernel_handles[i]; + } + kernel_count = rhs.kernel_count; + launch_attributes = std::move(rhs.launch_attributes); + return *this; + } + + /// Ctor + CUTLASS_HOST_DEVICE + CudaHostAdapter(void **kernel_handles_, + int32_t kernel_count_, + CudaHostLaunchAttributes const &launch_attributes_ = { }) + : kernel_count(kernel_count_), + launch_attributes(launch_attributes_) { + CUTLASS_ASSERT(kernel_count >= 0 && kernel_count < kMaximumKernelCount); + + for (int32_t i = 0; i < kernel_count && i < kMaximumKernelCount; ++i) { + kernel_handles[i] = kernel_handles_[i]; + } + } + + /// Returns true if the CudaHostAdapter is empty (kernel_count == 0) + CUTLASS_HOST_DEVICE + bool empty() const { return !kernel_count; } + + /// Returns kernel_count + CUTLASS_HOST_DEVICE + size_t size() const { return static_cast(kernel_count); } + + /// Queries the occupancy of a kernel + virtual Status query_occupancy( + int32_t *device_sms, + int32_t *sm_occupancy, + int32_t kernel_index, + int32_t thread_count, + int32_t smem_size) const = 0; + + /// Launches a kernel without using Threadblock Clusters. + virtual Status launch( + dim3 const grid_dims, + dim3 const block_dims, + size_t const smem_size, + cudaStream_t cuda_stream, + void** kernel_params, + int32_t kernel_index) const = 0; + + /// Launches a kernel using the CUDA Extensible Launch API and Threadblock Clusters. 
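// Call-site sketch for the two launch() overloads (hypothetical dimensions; `adapter`
// points to a caller-supplied concrete CudaHostAdapter subclass):
//
//   // Plain launch of kernel 0:
//   adapter->launch(dim3(grid_m, grid_n, 1), dim3(128, 1, 1), smem_bytes, stream,
//                   kernel_params, /*kernel_index=*/0);
//
//   // The cluster-enabled overload declared below additionally takes the cluster shape,
//   // e.g. dim3(2, 1, 1), between the grid and block dimensions.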
+ virtual Status launch( + dim3 const grid_dims, + dim3 const cluster_dims, + dim3 const block_dims, + size_t const smem_size, + cudaStream_t cuda_stream, + void** kernel_params, + int32_t kernel_index) const = 0; + +#if defined(CUDA_HOST_ADAPTER_TENSORMAP_ENABLED) + + /// Create a tensor map descriptor object representing im2col memory region. + virtual CUresult tensorMapEncodeIm2col ( + CUtensorMap* tensorMap, + CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, + void* globalAddress, + const cuuint64_t* globalDim, + const cuuint64_t* globalStrides, + const int* pixelBoxLowerCorner, + const int* pixelBoxUpperCorner, + cuuint32_t channelsPerPixel, + cuuint32_t pixelsPerColumn, + const cuuint32_t* elementStrides, + CUtensorMapInterleave interleave, + CUtensorMapSwizzle swizzle, + CUtensorMapL2promotion l2Promotion, + CUtensorMapFloatOOBfill oobFill) const = 0; + + /// Create a tensor map descriptor object representing tiled memory region. + virtual CUresult tensorMapEncodeTiled ( + CUtensorMap* tensorMap, + CUtensorMapDataType tensorDataType, + cuuint32_t tensorRank, + void* globalAddress, + const cuuint64_t* globalDim, + const cuuint64_t* globalStrides, + const cuuint32_t* boxDim, + const cuuint32_t* elementStrides, + CUtensorMapInterleave interleave, + CUtensorMapSwizzle swizzle, + CUtensorMapL2promotion l2Promotion, + CUtensorMapFloatOOBfill oobFill) const = 0; + + /// Modify an existing tensor map descriptor with an updated global address. + virtual CUresult tensorMapReplaceAddress( + CUtensorMap* tensorMap, + void* globalAddress) const = 0; + +#endif // defined(CUDA_HOST_ADAPTER_TENSORMAP_ENABLED) + +protected: + + /** + * Fills a buffer in Global Memory with a byte sequence copied from host memory. + * This function can be overriden to dispatch to the appropriate cuMemsetD*Async API + */ + virtual Status memsetDeviceImpl( + void* destination, ///< Device memory pointer to be filled + void const* fill_value, ///< Value to be filled in the buffer + size_t fill_size, ///< Size of the data type to be used for filling the buffer + size_t count, ///< Number of elements of size fill_size + cudaStream_t stream) const = 0; + +public: + + /// Fills a buffer in Global Memory with a byte sequence copied from host memory + template + CUTLASS_HOST_DEVICE + Status memsetDevice( + void* destination, + FillValueType fill_value, + size_t count, + cudaStream_t stream) const { + return this->memsetDeviceImpl( + destination, + &fill_value, + sizeof(FillValueType), + count, + stream); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/cutlass.h b/include/cutlass/cutlass.h index ebc4c1e905..e12616a201 100644 --- a/include/cutlass/cutlass.h +++ b/include/cutlass/cutlass.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
* SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -35,53 +35,13 @@ #pragma once -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifdef CUTLASS_NAMESPACE -#define cutlass CUTLASS_NAMESPACE -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define CUTLASS_UNUSED(expr) do { (void)(expr); } while (0) - -#if !defined(__CUDACC_RTC__) - -#include - -#if defined(_MSC_VER) - #define CUTLASS_NOT_IMPLEMENTED() assert(0 && __FUNCSIG__) -#else - #define CUTLASS_NOT_IMPLEMENTED() assert(0 && __PRETTY_FUNCTION__) -#endif - -#else - -#if defined(_MSC_VER) - #define CUTLASS_NOT_IMPLEMENTED() assert(0 && __FUNCSIG__) -#else - #define CUTLASS_NOT_IMPLEMENTED() assert(0 && __PRETTY_FUNCTION__) -#endif - -#endif +#include "cutlass/arch/synclog.hpp" +#include "cutlass/detail/helper_macros.hpp" //////////////////////////////////////////////////////////////////////////////////////////////////// namespace cutlass { -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) -#define CUTLASS_HOST_DEVICE __forceinline__ __device__ __host__ -#define CUTLASS_DEVICE __forceinline__ __device__ -#elif defined(__CUDACC_RTC__) -#define CUTLASS_HOST_DEVICE __forceinline__ __device__ -#define CUTLASS_DEVICE __forceinline__ __device__ -#else -#define CUTLASS_HOST_DEVICE inline -#define CUTLASS_DEVICE inline -#endif - /// Status code returned by CUTLASS operations enum class Status { kSuccess, ///< Operation was successful. @@ -132,58 +92,64 @@ static char const* cutlassGetStatusString(cutlass::Status status) { //////////////////////////////////////////////////////////////////////////////////////////////////// - -#ifndef CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED -#define CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED 0 -#endif - - -// CUDA 10.1 introduces the mma instruction -#if !defined(CUTLASS_ENABLE_TENSOR_CORE_MMA) -#define CUTLASS_ENABLE_TENSOR_CORE_MMA 0 -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// - -#define CUTLASS_ASSERT(x) assert(x) +static const int NumThreadsPerWarp = 32; +static const int NumThreadsPerWarpGroup = 128; +static const int NumWarpsPerWarpGroup = NumThreadsPerWarpGroup / NumThreadsPerWarp; +static const int NumThreadsPerHalfWarp = NumThreadsPerWarp / 2; +static const int NumThreadsPerQuad = 4; +static const int NumThreadsPerQuadPair = NumThreadsPerQuad * 2; //////////////////////////////////////////////////////////////////////////////////////////////////// -// CUTLASS_PRAGMA_(UNROLL|NO_UNROLL) optimization directives for the CUDA compiler. -#if defined(__CUDA_ARCH__) - #if defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__)) - #define CUTLASS_PRAGMA_UNROLL _Pragma("unroll") - #define CUTLASS_PRAGMA_NO_UNROLL _Pragma("unroll 1") +/// Helper function to return true when called by thread 0 of threadblock 0. 
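// The constants added above supersede the former NUM_THREADS_PER_* globals removed later
// in this hunk: NumWarpsPerWarpGroup = 128 / 32 = 4, NumThreadsPerHalfWarp = 16, and
// NumThreadsPerQuadPair = 8. The canonical_lane_idx() / canonical_warp_idx*() /
// canonical_warp_group_idx() helpers that follow are expressed in terms of these constants.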
+CUTLASS_HOST_DEVICE bool thread0() { + #if defined(__CUDA_ARCH__) + return (!threadIdx.x && !threadIdx.y && !threadIdx.z) && (!blockIdx.x && !blockIdx.y && !blockIdx.z); #else - #define CUTLASS_PRAGMA_UNROLL #pragma unroll - #define CUTLASS_PRAGMA_NO_UNROLL #pragma unroll 1 + return false; #endif +} - #define CUTLASS_GEMM_LOOP CUTLASS_PRAGMA_NO_UNROLL - -#else - - #define CUTLASS_PRAGMA_UNROLL - #define CUTLASS_PRAGMA_NO_UNROLL - #define CUTLASS_GEMM_LOOP - -#endif - -//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Returns a lane index in the warp. The threads in warp may not be convergent +CUTLASS_DEVICE +int canonical_lane_idx() { + #if defined(__CUDA_ARCH__) + return threadIdx.x % NumThreadsPerWarp; + #else + return 0; + #endif +} -static const int NUM_THREADS_PER_WARP = 32; -static const int NUM_THREADS_PER_HALF_WARP = NUM_THREADS_PER_WARP / 2; -static const int NUM_THREADS_PER_QUAD = 4; -static const int NUM_THREADS_PER_QUAD_PAIR = NUM_THREADS_PER_QUAD * 2; +/// Returns a warp-uniform value indicating the canonical warp index of the calling threads. +/// Threads within the warp must be converged. +CUTLASS_DEVICE +int canonical_warp_idx_sync() { + #if defined(__CUDA_ARCH__) + return __shfl_sync(0xffffffff, threadIdx.x / NumThreadsPerWarp, 0); + #else + return 0; + #endif +} -//////////////////////////////////////////////////////////////////////////////////////////////////// +/// Returns a warp index in the CTA. The threads in warp may not be convergent +/// As it doesn't sync the warp, it faster and allows forward progress +CUTLASS_DEVICE +int canonical_warp_idx() { + #if defined(__CUDA_ARCH__) + return threadIdx.x / NumThreadsPerWarp; + #else + return 0; + #endif +} -/// Helper function to return true when called by thread 0 of threadblock 0. -CUTLASS_HOST_DEVICE bool thread0() { +/// Returns a warp-uniform value indicating the canonical warp group index of the calling threads. +/// Threads within the warp must be converged. +CUTLASS_DEVICE +int canonical_warp_group_idx() { #if defined(__CUDA_ARCH__) - return (!threadIdx.x && !threadIdx.y && !threadIdx.z) && (!blockIdx.x && !blockIdx.y && !blockIdx.z); + return __shfl_sync(0xffffffff, threadIdx.x / NumThreadsPerWarpGroup, 0); #else - return false; + return 0; #endif } @@ -192,4 +158,3 @@ CUTLASS_HOST_DEVICE bool thread0() { } // namespace cutlass //////////////////////////////////////////////////////////////////////////////////////////////////// - diff --git a/include/cutlass/detail/collective.hpp b/include/cutlass/detail/collective.hpp new file mode 100644 index 0000000000..9d8f9e2f1d --- /dev/null +++ b/include/cutlass/detail/collective.hpp @@ -0,0 +1,64 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/container/tuple.hpp" +#include "cute/layout.hpp" // cute::size(shape) +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +struct deduce_mixed_width_dtype { +static_assert(I >= 0u && I <= 2u, "Valid indices are 0, 1, and 2, which represent Operand, Scale, and Bias, respectively."); + +private: + using underlying_tuple = cute::conditional_t::value, Tuple, cute::tuple>; + static constexpr size_t valid_index = cute::min(I, cute::tuple_size_v - 1); + +public: + using type = cute::conditional_t<(I < cute::tuple_size_v), + cute::tuple_element_t, + void>; +}; + +template +using deduce_mixed_width_dtype_t = typename deduce_mixed_width_dtype::type; + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::gemm::collective diff --git a/include/cutlass/detail/collective/mixed_input_utils.hpp b/include/cutlass/detail/collective/mixed_input_utils.hpp new file mode 100644 index 0000000000..c740eb98b2 --- /dev/null +++ b/include/cutlass/detail/collective/mixed_input_utils.hpp @@ -0,0 +1,1017 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/numeric_conversion.h" + +#include "cute/util/type_traits.hpp" +#include "cute/arch/copy_sm90.hpp" +#include "cute/numeric/arithmetic_tuple.hpp" + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + +// The universal converter +template < + class SrcType, + class DstType, + class LayoutIn, + class LayoutOut +> +struct LayoutAwareConvertImpl { + template + CUTLASS_DEVICE + static void convert( + cute::Tensor const& src, + cute::Tensor & dst) { + + static_assert(cute::is_same_v && + cute::is_same_v); + static_assert(cute::cosize_v == cute::cosize_v); + constexpr int N = decltype(cute::max_common_vector(LayoutIn{}, LayoutOut{})){}; + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + using Converter = cutlass::NumericArrayConverter; + auto&& src_vm = cute::recast(src); + auto&& dst_vm = cute::recast(dst); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i BF16 with [02461357] value order +template <> +struct LayoutAwareConvertImpl< + cutlass::int4b_t, + cutlass::bfloat16_t, + cute::Layout, cute::Stride<_4,_1>>, + cute::Layout<_8> +> { + template + CUTLASS_DEVICE + static void convert( + cute::Tensor, cute::Stride<_4,_1>> + > const& src, + cute::Tensor + >& dst) { + + static_assert(cute::is_same_v && + cute::is_same_v); + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + using RegArray = cutlass::AlignedArray; + + auto&& src_reg = cute::recast(src)(0); + auto&& r = cute::recast(dst)(0); + CUTLASS_PRAGMA_UNROLL + for (size_t ii = 0; ii < RegArray::kElements; ++ii) { + r[ii] = src_reg >> (4 * (ii)); + static constexpr uint32_t xor_mask = 0x43084308; + static constexpr uint32_t lo_mask = 0x000F000F; + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(lo_mask), "n"(xor_mask), "n"(immLut)); + static constexpr uint32_t lo_bias = xor_mask; // 0x43084308, {136, 136} + { + __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + bf16x2_val = __hsub2(bf16x2_val, + reinterpret_cast(lo_bias)); + } + } + } +}; + +// Specialization for UINT4 -> BF16 with [02461357] value order +template <> +struct LayoutAwareConvertImpl< + cutlass::uint4b_t, + cutlass::bfloat16_t, + cute::Layout, cute::Stride<_4,_1>>, + cute::Layout<_8> +> { + template + CUTLASS_DEVICE + static void convert( + cute::Tensor, cute::Stride<_4,_1>> + > const& src, + cute::Tensor + >& dst) { + + static_assert(cute::is_same_v && + cute::is_same_v); + using SrcArray = 
cutlass::Array; + using DstArray = cutlass::Array; + using RegArray = cutlass::AlignedArray; + + auto&& src_reg = cute::recast(src)(0); + auto&& r = cute::recast(dst)(0); + CUTLASS_PRAGMA_UNROLL + for (size_t ii = 0; ii < RegArray::kElements; ++ii) { + r[ii] = src_reg >> (4 * (ii)); + static constexpr uint32_t or_mask = 0x43004300; + static constexpr uint32_t lo_mask = 0x000F000F; + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(lo_mask), "n"(or_mask), "n"(immLut)); + static constexpr uint32_t lo_bias = or_mask; // 0x43004300, {128, 128} + { + __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + bf16x2_val = __hsub2(bf16x2_val, + reinterpret_cast(lo_bias)); + } + } + } +}; + +// Specialization for INT4 -> FP16 with [02461357] value order +template <> +struct LayoutAwareConvertImpl< + cutlass::int4b_t, + cutlass::half_t, + cute::Layout, cute::Stride<_4,_1>>, + cute::Layout<_8> +> { + template + CUTLASS_DEVICE + static void convert( + cute::Tensor, cute::Stride<_4,_1>> + > const& src, + cute::Tensor + >& dst) { + + static_assert(cute::is_same_v && + cute::is_same_v); + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + using RegArray = cutlass::AlignedArray; + + auto&& src_reg = cute::recast(src)(0); + auto&& r = cute::recast(dst)(0); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ii += 2) { + auto src_ = src_reg >> (4 * (ii)); + r[ii + 0] = src_; + r[ii + 1] = src_; + static constexpr uint32_t lo_xor_mask = 0x64086408; + static constexpr uint32_t hi_xor_mask = 0x64806480; + static constexpr uint32_t lo_mask = 0x000F000F; + static constexpr uint32_t hi_mask = 0x00F000F0; + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 0]) + : "n"(lo_mask), "n"(lo_xor_mask), "n"(immLut)); + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 1]) + : "n"(hi_mask), "n"(hi_xor_mask), "n"(immLut)); + static constexpr uint32_t lo_bias = 0x64086408; // {1032, 1032} + static constexpr uint32_t hi_bias = 0xD480D480; // {-72, -72} + static constexpr uint32_t hi_scale = 0x2C002C00; // {1/16, 1/16} + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]); + fp16x2_val = __hsub2(fp16x2_val, + reinterpret_cast(lo_bias)); + } + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]); + fp16x2_val = __hfma2(fp16x2_val, + reinterpret_cast(hi_scale), + reinterpret_cast(hi_bias)); + } + } + } +}; + +// Specialization for UINT4 -> FP16 with [02461357] value order +template <> +struct LayoutAwareConvertImpl< + cutlass::uint4b_t, + cutlass::half_t, + cute::Layout, cute::Stride<_4,_1>>, + cute::Layout<_8> +> { + template + CUTLASS_DEVICE + static void convert( + cute::Tensor, cute::Stride<_4,_1>> + > const& src, + cute::Tensor + >& dst) { + + static_assert(cute::is_same_v && + cute::is_same_v); + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + using RegArray = cutlass::AlignedArray; + + auto&& src_reg = cute::recast(src)(0); + auto&& r = cute::recast(dst)(0); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ii += 2) { + auto src_ = src_reg >> (4 * (ii)); + r[ii + 0] = src_; + r[ii + 1] = src_; + static constexpr uint32_t or_mask = 0x64006400; + static constexpr uint32_t lo_mask = 0x000F000F; + static constexpr uint32_t hi_mask = 0x00F000F0; + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 
0xaa; + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(lo_mask), "n"(or_mask), "n"(immLut)); + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii + 1]) + : "n"(hi_mask), "n"(or_mask), "n"(immLut)); + static constexpr uint32_t lo_bias = or_mask; // 0x64006400, {1024, 1024} + static constexpr uint32_t hi_bias = 0xD400D400; // {-64, -64} + static constexpr uint32_t hi_scale = 0x2C002C00; // {1/16, 1/16} + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 0]); + fp16x2_val = __hsub2(fp16x2_val, + reinterpret_cast(lo_bias)); + } + { + half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii + 1]); + fp16x2_val = __hfma2(fp16x2_val, + reinterpret_cast(hi_scale), + reinterpret_cast(hi_bias)); + } + } + } +}; + +// Specialization for E5M2 -> FP16 with [3120] value order +template <> +struct LayoutAwareConvertImpl< + cutlass::float_e5m2_t, + cutlass::half_t, + cute::Layout, cute::Stride<_2,_1>>, + cute::Layout<_4> +> { + template + CUTLASS_DEVICE + static void convert( + cute::Tensor, cute::Stride<_2,_1>> + > const& src, + cute::Tensor + >& dst) { + + static_assert(cute::is_same_v && + cute::is_same_v); + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + using RegArray = cutlass::AlignedArray; + + auto&& src_reg = cute::recast(src)(0); + auto&& r = cute::recast(dst)(0); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + // in registers: a3, a1, a2, a0 + r[RegArray::kElements - ii - 1] = src_reg << (8 * (ii)); + + static constexpr uint32_t and_mask = 0xFF00FF00; + asm volatile( + "{\n" + " and.b32 %0, %0, %1;\n" + "}\n" + : "+r"(r[ii]) + : "n"(and_mask)); + } + } +}; + +// Specialization for INT8 -> BF16 with [3120] value order +template <> +struct LayoutAwareConvertImpl< + cutlass::int8_t, + cutlass::half_t, + cute::Layout, cute::Stride<_2,_1>>, + cute::Layout<_4> +> { + template + CUTLASS_DEVICE + static void convert( + cute::Tensor, cute::Stride<_2,_1>> + > const& src, + cute::Tensor + >& dst) { + + static_assert(cute::is_same_v && + cute::is_same_v); + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + using RegArray = cutlass::AlignedArray; + + auto&& src_reg = cute::recast(src)(0); + auto&& r = cute::recast(dst)(0); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + uint32_t tmp0, tmp1; + r[ii] = src_reg >> (8 * (ii)); + static constexpr uint32_t or_mask = 0x43004300; + static constexpr uint32_t and_mask_0 = 0x007F007F; + static constexpr uint32_t and_mask_1 = 0x00800080; + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + asm volatile( + "{\n" + " lop3.b32 %0, %1, %2, %3, %4;\n" + "}\n" + : "=r"(tmp0) + : "r"(r[ii]), "n"(and_mask_0), "n"(or_mask), "n"(immLut)); + asm volatile( + "{\n" + " lop3.b32 %0, %1, %2, %3, %4;\n" + "}\n" + : "=r"(tmp1) + : "r"(r[ii]), "n"(and_mask_1), "n"(or_mask), "n"(immLut)); + { + __nv_bfloat162& bf16x2_val = reinterpret_cast<__nv_bfloat162&>(r[ii]); + bf16x2_val = __hsub2(reinterpret_cast<__nv_bfloat162 const&>(tmp0), + reinterpret_cast<__nv_bfloat162 const&>(tmp1)); + } + } + } +}; + +// Specialization for INT8 -> FP16 with [3120] value order +template <> +struct LayoutAwareConvertImpl< + cutlass::int8_t, + cutlass::bfloat16_t, + cute::Layout, cute::Stride<_2,_1>>, + cute::Layout<_4> +> { + template + CUTLASS_DEVICE + static void convert( + cute::Tensor, cute::Stride<_2,_1>> + > const& src, + cute::Tensor + >& dst) { + + static_assert(cute::is_same_v && + cute::is_same_v); + using 
SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + using RegArray = cutlass::AlignedArray; + + auto&& src_reg = cute::recast(src)(0); + auto&& r = cute::recast(dst)(0); + CUTLASS_PRAGMA_UNROLL + for (int ii = 0; ii < RegArray::kElements; ++ii) { + r[ii] = src_reg >> (8 * (ii)); + static constexpr uint32_t xor_mask = 0x64806480; + static constexpr uint32_t and_mask = 0x00FF00FF; + static constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + asm volatile( + "{\n" + " lop3.b32 %0, %0, %1, %2, %3;\n" + "}\n" + : "+r"(r[ii]) + : "n"(and_mask), "n"(xor_mask), "n"(immLut)); + { + static constexpr uint32_t bias = 0x64806480; + __half2& fp16x2_val = reinterpret_cast<__half2&>(r[ii]); + fp16x2_val = __hsub2(fp16x2_val, + reinterpret_cast<__half2 const&>(bias)); + } + } + } +}; + +template < + class EngineIn, + class EngineOut, + class LayoutIn, + class LayoutOut +> +CUTLASS_DEVICE +void LayoutAwareConvert( // Accept mutable temporaries + cute::Tensor const& src, + cute::Tensor && dst) { + + LayoutAwareConvert(src, dst); +} +template < + class EngineIn, + class EngineOut, + class LayoutIn, + class LayoutOut +> +CUTLASS_DEVICE +void LayoutAwareConvert( + cute::Tensor const& src, + cute::Tensor & dst) { + + using SrcType = typename EngineIn::value_type; + using DstType = typename EngineOut::value_type; + Tensor src_vm = coalesce(src); + Tensor dst_vm = coalesce(dst); + Layout src_layout = src_vm.layout(); + Layout dst_layout = dst_vm.layout(); + LayoutAwareConvertImpl::convert(src_vm, dst_vm); +} + +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::gemm::collective::detail { + +template +static constexpr +CUTLASS_HOST_DEVICE +auto get_logical_ptr(PointerType const* ptr) { + if constexpr (cute::sizeof_bits_v < 8) { + return subbyte_iterator(ptr); + } + else { + return ptr; + } +} +template +static constexpr +CUTLASS_HOST_DEVICE +auto get_smem_layout(LayoutAtom layout_atom, TileShape const& tile_shape, Stride const& stride) { + if constexpr (not cute::is_layout::value) { + return tile_to_shape( + layout_atom, + append(tile_shape, Int{}), + cute::conditional_t< ::cutlass::gemm::detail::is_major<0,Stride>(), Step<_2,_1,_3>, Step<_1,_2,_3>>{}); + } + else { + auto gmem_tile = composition(stride, tile_shape); + return make_layout_like(append(gmem_tile, make_layout(Int{}, 0))); + } +} +template +static constexpr +CUTLASS_HOST_DEVICE +auto get_gmem_layout(Shape const& shape, Stride const& stride) { + if constexpr (not cute::is_layout::value) { + return make_layout(shape, stride); + } + else { + return stride; + } +} + +template +struct MixedInputUtils { +private: + using KernelSchedule = typename Collective::KernelSchedule; + using ConversionMode = typename Collective::ConversionMode; + using SmemLayoutA = typename Collective::SmemLayoutA; + using SmemLayoutB = typename Collective::SmemLayoutB; + using SmemLayoutScale = typename Collective::SmemLayoutScale; + using SwappedElementA = typename Collective::SwappedElementA; + using SwappedElementB = typename Collective::SwappedElementB; + using RealSwappedElementA = typename Collective::RealSwappedElementA; + using RealSwappedElementB = typename Collective::RealSwappedElementB; + using ElementScale = typename Collective::ElementScale; + using ElementZero = typename Collective::ElementZero; + using SmemCopyAtomScale = typename Collective::SmemCopyAtomScale; + static constexpr auto KernelConversionMode = Collective::KernelConversionMode; + static constexpr auto 
ModeHasScales = Collective::ModeHasScales; + static constexpr auto UseScaleLookupTable = Collective::UseScaleLookupTable; + +public: + static constexpr auto + elements_per_smem_scale() { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return 0; + } + else if constexpr (ModeHasScales) { + return cute::cosize_v; + } + else { + static_assert(cutlass::detail::dependent_false, "Type not handled in scale smem allocation."); + } + } + + static constexpr auto + elements_per_smem_zero() { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert || + KernelConversionMode == ConversionMode::ConvertAndScale ) { + return 0; + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + return cute::cosize_v; + } + else { + static_assert(cutlass::detail::dependent_false, "Type not handled in scale smem allocation."); + } + } + + // These methods use some the public members of the class. For that reason, we define them after the public section. + static constexpr uint32_t + compute_tma_transaction_bytes_mk() { + return cutlass::bits_to_bytes(size<0>(SmemLayoutA{}) * size<1>(SmemLayoutA{}) * static_cast(cute::sizeof_bits_v)); + } + + static constexpr uint32_t + compute_tma_transaction_bytes_nk() { + return cutlass::bits_to_bytes(size<0>(SmemLayoutB{}) * size<1>(SmemLayoutB{}) * static_cast(cute::sizeof_bits_v)); + } + + static constexpr uint32_t + compute_tma_transaction_bytes_extra() { + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return 0; + } + else if constexpr (ModeHasScales) { + constexpr uint32_t scale_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v)); + static_assert(scale_tx_bytes % 128 == 0, "Each scale stage must be 128B aligned."); // required by TMA + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return scale_tx_bytes; + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + // Scale and zero share smem layout + constexpr uint32_t zero_tx_bytes = cutlass::bits_to_bytes(size<0>(SmemLayoutScale{}) * size<1>(SmemLayoutScale{}) * static_cast(cute::sizeof_bits_v)); + static_assert(zero_tx_bytes % 128 == 0, "Each zero stage must be 128B aligned."); // required by TMA + return scale_tx_bytes + zero_tx_bytes; + } + else { + static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Type not handled in tma transaction bytes computation."); + } + } + + /// Utilities to copy A and extra inputs from smem to RF + template + CUTLASS_DEVICE + static void copy_tensors_MK( + SmemTiledCopyA const& smem_tiled_copy_A, + TensorASmemView const& tCsA, + TensorACopyView& tCrA_copy_view, + cute::tuple const& partitioned_mma_extra_info, + cute::tuple const& tiled_copy_and_views, + int k_block, + int read_stage) { + + copy(smem_tiled_copy_A, tCsA(_,_,k_block,read_stage), tCrA_copy_view(_,_,k_block)); + + if (k_block == 0) { + // We are starting a new k-tile so copy the scale + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + } + else if constexpr (ModeHasScales) { + auto smem_tiled_copy_S = cute::get<0>(tiled_copy_and_views); + auto tCrS_copy_view = cute::get<1>(tiled_copy_and_views); + auto tCsS = cute::get<0>(partitioned_mma_extra_info); + copy(smem_tiled_copy_S, tCsS(_,_,k_block,read_stage), tCrS_copy_view(_,_,k_block)); + 
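+        // Note: the scale fragment copied just above is needed by both scaled modes;
+        // the ConvertAndScaleWithZero branch below additionally stages the zero-point
+        // fragment through the same tiled copy.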
if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + // Nothing extra to do + } else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + auto tCsZ = cute::get<2>(partitioned_mma_extra_info); + auto tCrZ_copy_view = cute::get<2>(tiled_copy_and_views); + copy(smem_tiled_copy_S, tCsZ(_,_,k_block,read_stage), tCrZ_copy_view(_,_,k_block)); + } else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + } + + // The core converter uses a lookup table to converts i4 -> 8 bit value. + template + CUTLASS_DEVICE + static void lookup_table_convert( // Accept mutable temporaries + Tensor const& src, + Tensor && dst, + Tensor const& scales_neg, + Tensor const& scales_pos) { + + lookup_table_convert(src, dst, scales_neg, scales_pos); + } + template + CUTLASS_DEVICE + static void lookup_table_convert( + Tensor const& src, + Tensor & dst, + Tensor const& scales_neg, + Tensor const& scales_pos) { + + constexpr int N = cute::cosize(LayoutIn{}); + static_assert(N == 4 || N == 8); + static_assert(cosize(LayoutScale{}) <= N / 4, + "at least 4 consecutive weights must share the same scale."); + using SrcArray = cutlass::Array; + using DstArray = cutlass::Array; + using RegArray = cutlass::AlignedArray; + + // View the input as reg + auto&& src_reg = cute::recast(src)(0); + auto&& r = cute::recast(dst)(0); + + // Determines if to get from the signed or unsigned candidates + static constexpr uint32_t immLut = (0xf0 & 0xcc) | 0xaa; + uint32_t sign; // ((reg & 0x88888888) | 0x64206420) >> 1 + asm volatile( + "{\n" + " lop3.b32 %0, %1, %2, %3, %4;\n" \ + "}\n" + : "=r"(sign) + : "r"(src_reg), "n"(0x88888888), "n"(0x64206420), "n"(immLut) + ); + sign = sign >> 1; + + // Ignore sign bit when indexing into LUT + uint32_t lut_idx = src_reg & 0x77777777; + Tensor scales_neg_ = cute::filter(scales_neg); + Tensor scales_pos_ = cute::filter(scales_pos); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < N / 4; ++i, lut_idx >>=16, sign >>=16) { + auto&& scale_neg_ = reinterpret_cast const&>(scales_neg_(i)); + auto&& scale_pos_ = reinterpret_cast const&>(scales_pos_(i)); + asm volatile( + "{\n" + " .reg .b32 pos, neg ;\n" \ + " prmt .b32 neg, %3, %4, %1 ;\n" \ + " prmt .b32 pos, %5, %6, %1 ;\n" \ + " prmt .b32 %0, pos, neg, %2 ;\n" \ + "}\n" + : "=r"(r[i]) + : "r"(lut_idx), "r"(sign), "r"(scale_neg_[0]), "r"(scale_neg_[1]), "r"(scale_pos_[0]), "r"(scale_pos_[1]) + ); + } + } + + /// Utilities to dequantize A. 
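+  /// In scalar terms (a rough sketch that ignores the register-level packing and
+  /// vectorization used below), the supported modes compute, per element,
+  ///   DirectConvert:            dst = DstType(src)
+  ///   ConvertAndScale:          dst = DstType(src) * scale
+  ///   ConvertAndScaleWithZero:  dst = DstType(src) * scale + zero
+  /// e.g. a hypothetical element-wise reference of the last mode would be
+  ///   Dst dequantize_one(Src q, Scale s, Zero z) { return Dst(q) * Dst(s) + Dst(z); }
+  /// The scale-lookup-table path instead folds the int4 -> 8-bit conversion and the
+  /// scale into byte-permutes over two 8-entry tables (negative and positive halves),
+  /// selected by the sign bit of each 4-bit source value.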
+ template + CUTLASS_DEVICE + static void static_check_scale(Layout const& tensor) { + static_assert(shape<0>(Layout{}) >= 4 && stride<0>(Layout{}) == 0, "At least 4 adjacent weights in a thread must share the same scale."); + } + template + CUTLASS_DEVICE + static void static_check_scale(Tensor const& tensor) { + static_check_scale(flatten(Layout{})); + } + template + CUTLASS_DEVICE + static void dequantize_A_kblock( + Tensor const& tCrA_load, + Tensor& tCrA_mma, + cute::tuple& partitioned_extra_info, + int const k_block) { + + static_assert(is_rmem::value, "Input tensor for A conversion must come from registers"); + static_assert(is_rmem::value, "Output tensor for A conversion must come from registers"); + static_assert(cosize_v == cosize_v); + static_assert(size_v == cosize_v); + static_assert(size_v == cosize_v); + using SrcType = typename EngineIn::value_type; + using DstType = typename EngineOut::value_type; + + Tensor src = tCrA_load(_, _, k_block); + Tensor dst = tCrA_mma(_, _, k_block); + + CUTE_STATIC_ASSERT_V(size(src(_, 0)) == cosize(src(_, 0).layout()), + "The first mode of tensor src must be contiguous in memory"); + // try to make the size of the first mode equal to 32bit + int constexpr NumValPerSrcReg = cute::min(decltype(size(src(_, 0)))::value, + ceil_div(32, sizeof_bits_v)); + Tensor src_vm = cute::group_modes<1,-1>(cute::zipped_divide(src, Int{})); + Tensor dst_vm = cute::group_modes<1,-1>(cute::zipped_divide(dst, Int{})); + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(dst_vm); ++i) { + LayoutAwareConvert(src_vm(_, i), dst_vm(_, i)); + } + } + else if constexpr (UseScaleLookupTable) { + constexpr int num_elements = decltype(size(src))::value; + static_assert(is_same_v, "Lookup table only supports int4 being the quant type now."); + static_assert(sizeof_bits_v == 64, "Lookup table only supports 8 8bit scale values now."); + static_assert(num_elements % 4 == 0 && num_elements >= 4, "Lookup table requires a vector size of 4x when converting."); + + Tensor tCrS_neg = cute::get<1>(partitioned_extra_info); + auto&& tCrS_pos = cute::get<2>(partitioned_extra_info); // modification to its value is needed + Tensor scales_neg = tCrS_neg(_, _, k_block); + Tensor scales_pos = tCrS_pos(_, _, k_block); + CUTE_STATIC_ASSERT_V(cute::size(src) == cute::size(scales_neg)); + + static_check_scale(scales_neg); + static_check_scale(scales_pos); + Tensor scales_neg_vm = cute::group_modes<1,-1>(cute::zipped_divide(scales_neg, Int{})); + Tensor scales_pos_vm = cute::group_modes<1,-1>(cute::zipped_divide(scales_pos, Int{})); + + if (k_block == 0) { + Tensor scales_neg_vm_ = filter(scales_neg_vm); + Tensor scales_pos_vm_ = filter(scales_pos_vm); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(scales_neg_vm_.layout()); ++i) + { + auto&& scale_neg_ = reinterpret_cast const&>(scales_neg_vm_(i)); + auto&& scale_pos_ = reinterpret_cast &>(scales_pos_vm_(i)); + constexpr uint32_t immLut = (0xf0 & 0xcc) ^ 0xaa; + asm volatile( + "{\n" + " lop3 .b32 %0, %2, %4, %5, %6;\n" \ + " xor .b32 %1, %3, %5; \n" \ + "}\n" + : "=r"(scale_pos_[0]), "=r"(scale_pos_[1]) + : "r"(scale_neg_[0]), "r"(scale_neg_[1]), "n"(0xFFFFFF00), "n"(0x80808080), "n"(immLut) + ); + } + } + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(dst_vm); ++i) { + lookup_table_convert(src_vm(_, i), dst_vm(_, i), scales_neg_vm(_, i), scales_pos_vm(_, i)); + } + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + Tensor scales 
= cute::get<1>(partitioned_extra_info)(_, _, k_block); + CUTE_STATIC_ASSERT_V(size(src) == size(scales)); + Tensor scales_vm = cute::group_modes<1,-1>(cute::zipped_divide(scales, Int{})); + + if constexpr (is_same_v) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(dst_vm); ++i) { + LayoutAwareConvert(src_vm(_, i), dst_vm(_, i)); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<0>(dst_vm); ++j) { + dst_vm(j, i) *= scales_vm(j, i); + } + } + } + else { + auto stage = make_tensor_like(src_vm(_, 0)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(dst_vm); ++i) { + LayoutAwareConvert(src_vm(_, i), stage); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<0>(dst_vm); ++j) { + stage(j) *= scales_vm(j, i); + } + LayoutAwareConvert(stage, dst_vm(_, i)); + } + } + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + static_assert(is_same_v, "ElementScale and ElementZero must be the same."); + Tensor scales = cute::get<1>(partitioned_extra_info)(_, _, k_block); + Tensor zeros = cute::get<3>(partitioned_extra_info)(_, _, k_block); + CUTE_STATIC_ASSERT_V(size(src) == size(scales)); + CUTE_STATIC_ASSERT_V(size(src) == size(zeros)); + Tensor scales_vm = cute::group_modes<1,-1>(cute::zipped_divide(scales, Int{})); + Tensor zeros_vm = cute::group_modes<1,-1>(cute::zipped_divide(zeros, Int{})); + + if constexpr (is_same_v) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(dst_vm); ++i) { + LayoutAwareConvert(src_vm(_, i), dst_vm(_, i)); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<0>(dst_vm); ++j) { + dst_vm(j, i) = dst_vm(j, i) * scales_vm(j, i) + zeros_vm(j, i); + } + } + } + else { + auto stage = make_tensor_like(src_vm(_, 0)); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<1>(dst_vm); ++i) { + LayoutAwareConvert(src_vm(_, i), stage); + CUTLASS_PRAGMA_UNROLL + for (int j = 0; j < size<0>(dst_vm); ++j) { + stage(j) = stage(j) * scales_vm(j, i) + zeros_vm(j, i); + } + LayoutAwareConvert(stage, dst_vm(_, i)); + } + } + } + else { + static_assert(cutlass::detail::dependent_false, "No A data is loaded."); + } + } + + /// Utilities for any additional inputs inside of the TMA load + template < + class Params, + class TensorStorage, + class... 
Ts + > + CUTLASS_DEVICE + static auto partition_extra_tma_inputs( + Params const& mainloop_params, + cute::tuple const& load_inputs, + TensorStorage& shared_tensors, + uint2 const& cluster_local_block_id, + int const m_coord, + int const l_coord) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + return cute::make_tuple(); + } + else if constexpr (ModeHasScales) { + Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) + Tensor gS_mkl = get<2>(load_inputs); + auto block_tma_s = mainloop_params.tma_load_scale.get_slice(cluster_local_block_id.y); + Tensor gS = gS_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + + Tensor tSgS = block_tma_s.partition_S(gS); // (TMA,TMA_M,TMA_K,k) + Tensor tSsS = block_tma_s.partition_D(sS); // (TMA,TMA_M,TMA_K,PIPE) + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tSgS, tSsS); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{}); // (BLK_M,BLK_K,PIPE) + Tensor gZ_mkl = get<3>(load_inputs); + auto block_tma_z = mainloop_params.tma_load_zero.get_slice(cluster_local_block_id.y); + Tensor gZ = gZ_mkl(_,_,m_coord,_,l_coord); // (BLK_M,BLK_K,k) + + Tensor tZgZ = block_tma_z.partition_S(gZ); // (TMA,TMA_M,TMA_K,k) + Tensor tZsZ = block_tma_z.partition_D(sZ); // (TMA,TMA_M,TMA_K,PIPE) + return cute::make_tuple(tSgS, tSsS, tZgZ, tZsZ); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled for input partitioning."); + } + } + + /// Utilities for partitioning extra inputs for loading from smem in the mainloop. 
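+  /// The tuple returned by the helper below mirrors the active conversion mode:
+  ///   DirectConvert           -> ()
+  ///   scale lookup table      -> (tCsS, tCrS_neg, tCrS_pos)
+  ///   ConvertAndScale         -> (tCsS, tCrS)
+  ///   ConvertAndScaleWithZero -> (tCsS, tCrS, tCsZ, tCrZ)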
+ template < + class ThreadMma, + class TensorStorage + > + CUTLASS_DEVICE + static auto partition_extra_mma_info( + ThreadMma const& mma_thread_slice, + TensorStorage& shared_tensors) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + return cute::make_tuple(); + } + else if constexpr (UseScaleLookupTable) { + Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsS = mma_thread_slice.partition_A(sS); + Tensor tCrS_neg = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout()); + Tensor tCrS_pos = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout()); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tCsS, tCrS_neg, tCrS_pos); + } + } + else if constexpr (ModeHasScales) { + Tensor sS = make_tensor(make_smem_ptr(shared_tensors.smem_scale.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsS = mma_thread_slice.partition_A(sS); + Tensor tCrS = make_tensor(mma_thread_slice.partition_fragment_A(sS(_,_,Int<0>{})).layout()); + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(tCsS, tCrS); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor sZ = make_tensor(make_smem_ptr(shared_tensors.smem_zero.begin()), SmemLayoutScale{});// (BLK_M,BLK_SCALE_K,PIPE) + Tensor tCsZ = mma_thread_slice.partition_A(sZ); + Tensor tCrZ = make_tensor(mma_thread_slice.partition_fragment_A(sZ(_,_,Int<0>{})).layout()); + return cute::make_tuple(tCsS, tCrS, tCsZ, tCrZ); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + + /// Returns the tiled copy and copy views for the extra inputs. 
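+  /// As with the partitioning helper above, the tuple shape follows the mode:
+  ///   DirectConvert           -> ()
+  ///   ConvertAndScale         -> (smem_tiled_copy_S, tCrS_copy_view)
+  ///   ConvertAndScaleWithZero -> (smem_tiled_copy_S, tCrS_copy_view, tCrZ_copy_view)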
+ template + CUTLASS_DEVICE + static auto retile_extra_mma_info( + TiledMma const& tiled_mma, + cute::tuple& partitioned_extra_info, + int const warp_group_thread_idx) { + + if constexpr (KernelConversionMode == ConversionMode::DirectConvert) { + // nothing to do + return cute::make_tuple(); + } + else if constexpr (ModeHasScales) { + auto smem_tiled_copy_S = make_tiled_copy_A(SmemCopyAtomScale{}, tiled_mma); + auto smem_thr_copy_S = smem_tiled_copy_S.get_thread_slice(warp_group_thread_idx); + Tensor tCrS_copy_view = smem_thr_copy_S.retile_D(cute::get<1>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) + + if constexpr (KernelConversionMode == ConversionMode::ConvertAndScale) { + return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view); + } + else if constexpr (KernelConversionMode == ConversionMode::ConvertAndScaleWithZero) { + Tensor tCrZ_copy_view = smem_thr_copy_S.retile_D(cute::get<3>(partitioned_extra_info)); // (CPY,CPY_M,CPY_K) + return cute::make_tuple(smem_tiled_copy_S, tCrS_copy_view, tCrZ_copy_view); + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } + else { + static_assert(cutlass::detail::dependent_false, "Conversion mode not handled in A -> RF path."); + } + } +}; + +} // cutlass::gemm::collective::detail diff --git a/include/cutlass/detail/dependent_false.hpp b/include/cutlass/detail/dependent_false.hpp new file mode 100644 index 0000000000..76e52d2bf8 --- /dev/null +++ b/include/cutlass/detail/dependent_false.hpp @@ -0,0 +1,86 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::detail { + +/// @brief A bool constant that depends on one or more template parameters. +/// +/// For more detailed documentation and use cases, +/// please see `dependent_false` below. +template +inline constexpr bool dependent_bool_value = Value; + +/// @brief An always-false value that depends on one or more template parameters. +/// +/// This exists because `static_assert(false);` always fails, +/// even if it occurs in the `else` branch of an `if constexpr`. +/// The following example shows how to use `dependent_false` in that case. +/// +/// @code +/// template +/// void foo (T t) +/// { +/// if constexpr (std::is_integral_v) { +/// do_integer_stuff(t); +/// } +/// else if constexpr (std::is_floating_point_v) { +/// do_floating_point_stuff(t); +/// } +/// else { +/// static_assert(dependent_false, "T must be " +/// "an integral or floating-point type."); +/// } +/// } +/// @endcode +/// +/// This implements the C++ Standard Library proposal P1830R1. +/// +/// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2019/p1830r1.pdf +/// +/// That proposal is under review as of 2022/12/05. +/// The following link shows P1830's current review status. +/// +/// https://github.com/cplusplus/papers/issues/572 +/// +/// P2593R0 proposes an alternate solution to this problem, +/// that would change the C++ language itself. +/// +/// https://www.open-std.org/jtc1/sc22/wg21/docs/papers/2022/p2593r0.html +/// +/// For headers in this library, however, we only consider library solutions +/// as work-arounds for future C++ features. +template +inline constexpr bool dependent_false = dependent_bool_value; + +} // end namespace cutlass::detail diff --git a/include/cutlass/detail/helper_macros.hpp b/include/cutlass/detail/helper_macros.hpp new file mode 100644 index 0000000000..039f5e841a --- /dev/null +++ b/include/cutlass/detail/helper_macros.hpp @@ -0,0 +1,211 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Helper macros for the CUTLASS library +*/ + +#pragma once + +//////////////////////////////////////////////////////////////////////////////////////////////////// + + +#ifdef CUTLASS_NAMESPACE +#define concat_tok(a, b) a ## b +#define mkcutlassnamespace(pre, ns) concat_tok(pre, ns) +#define cutlass mkcutlassnamespace(cutlass_, CUTLASS_NAMESPACE) +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(__NVCC__) || (defined(__clang__) && defined(__CUDA__)) +#define CUTLASS_HOST_DEVICE __forceinline__ __device__ __host__ +#define CUTLASS_DEVICE __forceinline__ __device__ +#elif defined(__CUDACC_RTC__) +#define CUTLASS_HOST_DEVICE __forceinline__ __device__ +#define CUTLASS_DEVICE __forceinline__ __device__ +#else +#define CUTLASS_HOST_DEVICE inline +#define CUTLASS_DEVICE inline +#endif + +#if ! defined(_MSC_VER) +#define CUTLASS_LAMBDA_FUNC_INLINE __attribute__((always_inline)) +#else +#define CUTLASS_LAMBDA_FUNC_INLINE [[msvc::forceinline]] +#endif + +#define CUTLASS_HOST __host__ +#define CUTLASS_GLOBAL __global__ static + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +CUTLASS_HOST_DEVICE void __CUTLASS_UNUSED(T const &) +{ } + +#if defined(__GNUC__) + #define CUTLASS_UNUSED(expr) __CUTLASS_UNUSED(expr) +#else + #define CUTLASS_UNUSED(expr) do { ; } while (&expr != &expr) +#endif + +#ifdef _MSC_VER +// Provides support for alternative operators 'and', 'or', and 'not' +#include +#endif // _MSC_VER + +#if !defined(__CUDACC_RTC__) +#include +#endif + +#if defined(__CUDA_ARCH__) + #if defined(_MSC_VER) + #define CUTLASS_NOT_IMPLEMENTED() { printf("%s not implemented\n", __FUNCSIG__); asm volatile ("brkpt;\n"); } + #else + #define CUTLASS_NOT_IMPLEMENTED() { printf("%s not implemented\n", __PRETTY_FUNCTION__); asm volatile ("brkpt;\n"); } + #endif +#else + #if defined(_MSC_VER) + #define CUTLASS_NOT_IMPLEMENTED() assert(0 && __FUNCSIG__) + #else + #define CUTLASS_NOT_IMPLEMENTED() assert(0 && __PRETTY_FUNCTION__) + #endif +#endif + +// CUTLASS_CMATH_NAMESPACE is the namespace where code can find +// functions like isnan and log. Such functions are in +// the std namespace in host code, but in the global namespace +// in device code. +// +// The intended use case for this macro is in "using" declarations +// for making argument-dependent lookup (ADL) work in generic code. +// For example, if T is cutlass::half_t, the following code will +// invoke cutlass::isnan(half_t). If T is float, it will invoke +// std::isnan on host and ::isnan on device. (CUTLASS's support +// for NVRTC prevents it from using things in the std namespace +// in device code.) Correct use of "using" declarations can help +// avoid unexpected implicit conversions, like from half_t to float. 
+// +// template +// bool foo(T x) { +// using CUTLASS_CMATH_NAMESPACE :: isnan; +// return isnan(x); +// } +// +// Without this macro, one would need to write the following. +// +// template +// bool foo(T x) { +// #if defined(__CUDA_ARCH__) +// using ::isnan; +// #else +// using std::isnan; +// #endif +// return isnan(x); +// } + +#if defined(__CUDA_ARCH__) +# define CUTLASS_CMATH_NAMESPACE +#else +# define CUTLASS_CMATH_NAMESPACE std +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { + + +#ifndef CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED +#define CUTLASS_CONV_UNIT_TEST_RIGOROUS_SIZE_ENABLED 0 +#endif + + +// CUDA 10.1 introduces the mma instruction +#if !defined(CUTLASS_ENABLE_TENSOR_CORE_MMA) +#define CUTLASS_ENABLE_TENSOR_CORE_MMA 0 +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#define CUTLASS_ASSERT(x) assert(x) + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// CUTLASS_PRAGMA_(UNROLL|NO_UNROLL) optimization directives for the CUDA compiler. +#if defined(__CUDA_ARCH__) && !defined(__INTELLISENSE__) + #if defined(__CUDACC_RTC__) || (defined(__clang__) && defined(__CUDA__)) + #define CUTLASS_PRAGMA_UNROLL _Pragma("unroll") + #define CUTLASS_PRAGMA_NO_UNROLL _Pragma("unroll 1") + #else + #define CUTLASS_PRAGMA_UNROLL #pragma unroll + #define CUTLASS_PRAGMA_NO_UNROLL #pragma unroll 1 + #endif + + #define CUTLASS_GEMM_LOOP CUTLASS_PRAGMA_NO_UNROLL + +#else + + #define CUTLASS_PRAGMA_UNROLL + #define CUTLASS_PRAGMA_NO_UNROLL + #define CUTLASS_GEMM_LOOP + +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if !defined(__CUDACC_RTC__) +#define CUTLASS_THREAD_LOCAL thread_local +#else +#define CUTLASS_THREAD_LOCAL +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +#if defined(_MSVC_LANG) +# define CUTLASS_CPLUSPLUS _MSVC_LANG +#else +# define CUTLASS_CPLUSPLUS __cplusplus +#endif + +#if (201700L <= CUTLASS_CPLUSPLUS) +#define CUTLASS_CONSTEXPR_IF_CXX17 constexpr +#define CUTLASS_CXX17_OR_LATER 1 +#else +#define CUTLASS_CONSTEXPR_IF_CXX17 +#define CUTLASS_CXX17_OR_LATER 0 +#endif + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +}; // namespace cutlass + +//////////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/detail/layout.hpp b/include/cutlass/detail/layout.hpp new file mode 100644 index 0000000000..cbed61f683 --- /dev/null +++ b/include/cutlass/detail/layout.hpp @@ -0,0 +1,406 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. 
Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cute/layout.hpp" +#include "cute/pointer_sparse.hpp" // cute::is_sparse +#include "cute/swizzle.hpp" // cute::Swizzle +#include "cute/swizzle_layout.hpp" // cute::detail::get_swizzle_portion +#include "cute/util/type_traits.hpp" +#include "cute/arch/copy_sm90_tma.hpp" +#include "cutlass/layout/matrix.h" +#include "cutlass/layout/tensor.h" +#include "cutlass/numeric_types.h" +#include "cutlass/detail/collective.hpp" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::detail { + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// For each cutlass::layout, provides its corresponding cute stride types, 64b by default + +template +struct TagToStrideA { + using type = L; +}; + +// Maps to modes [M, K, L] +template <> +struct TagToStrideA { + using type = cute::Stride, int64_t>; + using tag = layout::RowMajor; +}; + +// Maps to modes [M, K, L] +template <> +struct TagToStrideA { + using type = cute::Stride, int64_t, int64_t>; + using tag = layout::ColumnMajor; +}; + +template +struct TagToStrideB { + using type = L; +}; + +// Maps to modes [N, K, L] +template <> +struct TagToStrideB { + using type = cute::Stride, int64_t, int64_t>; + using tag = layout::RowMajor; +}; + +// Maps to modes [N, K, L] +template <> +struct TagToStrideB { + using type = cute::Stride, int64_t>; + using tag = layout::ColumnMajor; +}; + +// For each cutlass::layout *, provides its corresponding cute stride types, 64b by default +// Used by pointer array and grouped gemm +// Maps to modes [M, K, L] +template <> +struct TagToStrideA { + using UnderlyingType = cute::Stride, cute::Int<0>>; + using type = UnderlyingType*; + using tag = layout::RowMajor; +}; + +// Maps to modes [M, K, L] +template <> +struct TagToStrideA { + using UnderlyingType = cute::Stride, int64_t, cute::Int<0>>; + using type = UnderlyingType*; + using tag = layout::ColumnMajor; +}; + +// Maps to modes [N, K, L] +template <> +struct TagToStrideB { + using UnderlyingType = cute::Stride, int64_t, cute::Int<0>>; + using type = UnderlyingType*; + using tag = layout::RowMajor; +}; + +// Maps to modes [N, K, L] +template <> +struct TagToStrideB { + using UnderlyingType = cute::Stride, cute::Int<0>>; + using type = UnderlyingType*; + using tag = layout::ColumnMajor; +}; + +// Maps to modes [M, N, L] +template +struct TagToStrideC : 
TagToStrideA { }; + +// Conv: Maps to modes ((P,N), C, _0) for compatiblity with GEMM epilogues expecting a batch mode stride +template <> +struct TagToStrideC { + using type = cute::Stride, cute::Int<1>, cute::Int<0>>; +}; + +// Conv: Maps to modes ((P,Q,N), C, _0) for compatiblity with GEMM epilogues expecting a batch mode stride +template <> +struct TagToStrideC { + using type = cute::Stride, cute::Int<1>, cute::Int<0>>; +}; + +// Conv: Maps to modes ((P,Q,Z,N), C, _0) for compatiblity with GEMM epilogues expecting a batch mode stride +template <> +struct TagToStrideC { + using type = cute::Stride, cute::Int<1>, cute::Int<0>>; +}; + +// Conv: Maps to modes (K, (C,S), _0) for compatiblity with GEMM epilogues expecting a batch mode stride +template <> +struct TagToStrideC { + using type = cute::Stride, int64_t>, cute::Int<0>>; +}; + +// Conv: Maps to modes (K, (C,S,R), _0) for compatiblity with GEMM epilogues expecting a batch mode stride +template <> +struct TagToStrideC { + using type = cute::Stride, int64_t, int64_t>, cute::Int<0>>; +}; + +// Conv: Maps to modes (K, (C,S,R,T), _0) for compatiblity with GEMM epilogues expecting a batch mode stride +template <> +struct TagToStrideC { + using type = cute::Stride, int64_t, int64_t, int64_t>, cute::Int<0>>; +}; + +// Conv: Maps to modes ((C,S), K, _0) for compatiblity with GEMM epilogues expecting a batch mode stride +template <> +struct TagToStrideC { + using type = cute::Stride, int64_t>, int64_t, cute::Int<0>>; +}; + +// Conv: Maps to modes ((C,S,R), K, _0) for compatiblity with GEMM epilogues expecting a batch mode stride +template <> +struct TagToStrideC { + using type = cute::Stride, int64_t, int64_t>, int64_t, cute::Int<0>>; +}; + +// Conv: Maps to modes ((C,S,R,T), K, _0) for compatiblity with GEMM epilogues expecting a batch mode stride +template <> +struct TagToStrideC { + using type = cute::Stride, int64_t, int64_t, int64_t>, int64_t, cute::Int<0>>; +}; + +// Convenience aliases +template +using TagToStrideA_t = typename TagToStrideA::type; + +template +using TagToStrideB_t = typename TagToStrideB::type; + +template +using TagToStrideC_t = typename TagToStrideC::type; + +//////////////////////////////////////////////////////////////////////////////////////////////////// +// For 2.x compatibility APIs, provide stride->layout tag mappers + +template +constexpr bool +is_major(Stride = {}) { + // Account for stride types with and without batch mode and batch modes with static zero stride + return cute::is_constant<1, decltype(cute::front(cute::get(cute::remove_pointer_t{})))>::value; +} + +template +constexpr bool +is_major(cute::Layout = {}) { + return is_major(Stride{}); +} + +// Note : This method can be used for deducing the Layout Tag of A, C, D Matrices +template +constexpr +auto +stride_to_layout_tag_A() { + using InternalStrideA = cute::remove_pointer_t; + if constexpr (cute::is_layout::value) { + return stride_to_layout_tag_A(); + } + else if constexpr (is_major<0, StrideA>()) { // M major + return layout::ColumnMajor{}; + } + // Specialize for sparse layout + else if constexpr (cute::get<0>(InternalStrideA{}) == cute::_2{} && + cute::rank(cute::get<1>(InternalStrideA{})) == 2 && + cute::is_same_v(InternalStrideA{}))>>) { + return layout::ColumnMajor{}; + } + else { // K major + return layout::RowMajor{}; + } + + CUTE_GCC_UNREACHABLE; +} + +template +constexpr +auto +stride_to_layout_tag_B() { + using InternalStrideB = cute::remove_pointer_t; + if constexpr (cute::is_layout::value) { + return stride_to_layout_tag_B(); + } + 
else if constexpr (is_major<0, StrideB>()) { // N major + return layout::RowMajor{}; + } + else { // K major + return layout::ColumnMajor{}; + } + + CUTE_GCC_UNREACHABLE; +} + +template +constexpr +auto +stride_to_layout_tag_C() { + using InternalStrideC = cute::remove_pointer_t; + if constexpr (cute::is_layout::value) { + return stride_to_layout_tag_C(); + } + else if constexpr (is_major<0, StrideC>()) { // M major + return layout::ColumnMajor{}; + } + else { // N major + return layout::RowMajor{}; + } + + CUTE_GCC_UNREACHABLE; +} + +// Utilities to map Stride back on to their corresponding layout tags +template +struct StrideToLayoutTagA { + using type = decltype(detail::stride_to_layout_tag_A()); +}; + +template +struct StrideToLayoutTagB { + using type = decltype(detail::stride_to_layout_tag_B()); +}; + +template +struct StrideToLayoutTagC { + using type = decltype(detail::stride_to_layout_tag_C()); +}; + +// Convenience aliases +template +using StrideToLayoutTagA_t = typename StrideToLayoutTagA::type; + +template +using StrideToLayoutTagB_t = typename StrideToLayoutTagB::type; + +template +using StrideToLayoutTagC_t = typename StrideToLayoutTagC::type; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +// Inspects a tiled copy and whether its copy engine is TMA or not +template +constexpr bool is_tma_copy_engine() { + if constexpr (cute::is_void_v) { + return false; + } + else { + if constexpr ( cute::is_base_of_v + || cute::is_base_of_v + || cute::is_base_of_v + || cute::is_base_of_v + || cute::is_base_of_v + || cute::is_base_of_v + ) { + return true; + } + } + return false; +} + +template +struct RawDtype { using type = X; }; + +template +struct RawDtype> { using type = typename X::raw_type; }; + + +// Inspects a TiledCopy and returns its alignment in terms of element count +template +constexpr int +get_alignment_count_from_gmem_tiled_copy() { + + if constexpr (cute::is_void_v) { + return 1; + } + + // Account for ElementC = void kernels + else if constexpr (cute::is_void_v) { + return 0; + } + + else { + // For TMA tiled copies, we know the alignment has to be 128 bits + if constexpr (is_tma_copy_engine()) { + // For sparse MMA, alignment in logical elements is increased by sparsity factor + if constexpr (cute::is_sparse_v) { + return 128 / sizeof_bits::value * ElementMma::sparsity; + } + return 128 / sizeof_bits::value; + } + else { + // For non-TMA tiled copies, TiledCopy holds the alignment count directly in its TiledShape_MN + return GmemTiledCopy::NumValSrc; + } + } +} + +// Return alignment bit requirements for the GEMM inputs. +template < + class ElementType +> +constexpr int +get_input_alignment_bits() { + return 128; +} + +// Return alignment bit requirements for the GEMM outputs. 
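+// (Like the input requirement above, this currently resolves to a flat 128 bits.
+// As a quick illustration of how a 128-bit requirement translates into elements
+// for TMA copies:
+//   static_assert(128 / cutlass::sizeof_bits<cutlass::half_t>::value  ==  8, "");
+//   static_assert(128 / cutlass::sizeof_bits<cutlass::int4b_t>::value == 32, "");
+// i.e. 8 half_t or 32 int4b_t values per access.)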
+template +constexpr int +get_output_alignment_bits() { + return 128; +} + +// Check if tensor layout satisfies a given major alignment +template +CUTLASS_HOST_DEVICE constexpr +bool +check_alignment(cute::Layout const& layout) { + // Condition: shape must divide by Alignment without rounding + bool shape_check = cute::size(layout.shape()) == Alignment * cute::size(cute::upcast(layout)); + // Condition: every dynamic stride must be a multiple of Alignment + bool stride_check = cute::all_of(cute::flatten(layout.stride()), [](auto s){ return cute::is_static::value || (s % Alignment == 0); }); + return shape_check && stride_check; +} + +// Check if tensor layout satisfies a given major alignment +template +CUTLASS_HOST_DEVICE constexpr +bool +check_alignment(Shape const& shape, Stride const& stride) { + return check_alignment(cute::make_layout(shape, stride)); +} + +template +CUTLASS_HOST_DEVICE constexpr +size_t +alignment_for_swizzle(cute::Swizzle) { + static_assert(B >= 0 and M >= 0); + return size_t(1) << size_t(B + M + cute::abs(S)); +} + +template +CUTLASS_HOST_DEVICE constexpr +size_t +alignment_for_swizzle(Layout layout) { + return alignment_for_swizzle(cute::detail::get_swizzle_portion(layout)); +} + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::detail diff --git a/include/cutlass/detail/mainloop_fusion_helper_scale_factor.hpp b/include/cutlass/detail/mainloop_fusion_helper_scale_factor.hpp new file mode 100644 index 0000000000..914443dd0d --- /dev/null +++ b/include/cutlass/detail/mainloop_fusion_helper_scale_factor.hpp @@ -0,0 +1,75 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Mainloop Fusion configs specific for scale factors +*/ + +#pragma once + +#include // cute::void_t + +namespace cutlass::detail { + +///////////////////////////////////////////////////////////////////////////////////////////////// +template +struct ElementSFType { + using type = void; +}; + +template +struct ElementSFType> { + using type = typename CollectiveMainloop::ElementSF; +}; + +template +struct LayoutSFAType { + using type = void; +}; + +template +struct LayoutSFAType> { + using type = typename CollectiveMainloop::LayoutSFA; +}; + +template +struct LayoutSFBType { + using type = void; +}; + +template +struct LayoutSFBType> { + using type = typename CollectiveMainloop::LayoutSFB; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::detail diff --git a/include/cutlass/detail/mma.hpp b/include/cutlass/detail/mma.hpp new file mode 100644 index 0000000000..0e491b9c40 --- /dev/null +++ b/include/cutlass/detail/mma.hpp @@ -0,0 +1,71 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/arch/mma.h" +#include "cute/layout.hpp" + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::detail { + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct IsSparseTensorOp : cute::false_type { }; + +// TiledMma for sparse must have ValTypeE +template +struct IsSparseTensorOp> + : cute::true_type { }; + +// The following metafunction is used to extract the OperatorClass from a cutlass 3.x kernel. 
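// (Aside: the ElementSFType / LayoutSFAType / LayoutSFBType helpers in the scale-factor header
// above rely on the cute::void_t detection idiom. A self-contained sketch of the same pattern,
// using hypothetical illustration-only types that are not part of CUTLASS:)
struct ExampleMainloopWithSF    { using ElementSF = float; };   // hypothetical type
struct ExampleMainloopWithoutSF { };                            // hypothetical type

template <class T, class = void>
struct ExampleHasElementSF : cute::false_type { };              // primary: no ElementSF member

template <class T>
struct ExampleHasElementSF<T, cute::void_t<typename T::ElementSF>> : cute::true_type { };

static_assert( ExampleHasElementSF<ExampleMainloopWithSF>::value,    "detected via the void_t specialization");
static_assert(!ExampleHasElementSF<ExampleMainloopWithoutSF>::value, "falls back to the primary template");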
+template +struct get_operator_class { + static constexpr bool is_sparse_op = IsSparseTensorOp::value; + static constexpr bool is_tensor_op = cute::size<0>(typename TiledMma::AtomShape_MNK{}) >= 8; + using type = cute::conditional_t< + is_tensor_op, + cute::conditional_t< + is_sparse_op, + cutlass::arch::OpClassSparseTensorOp, + cutlass::arch::OpClassTensorOp + >, + cutlass::arch::OpClassSimt + >; +}; + +template +using get_operator_class_t = typename get_operator_class::type; + +//////////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::detail diff --git a/include/cutlass/device_kernel.h b/include/cutlass/device_kernel.h index 1de33024b9..cc7caede49 100644 --- a/include/cutlass/device_kernel.h +++ b/include/cutlass/device_kernel.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -34,20 +34,46 @@ #pragma once -#include "cutlass/cutlass.h" +#include // CUTLASS_HOST_DEVICE +#include // uint64_t + +// __grid_constant__ was introduced in CUDA 11.7. +#if ((__CUDACC_VER_MAJOR__ >= 12) || ((__CUDACC_VER_MAJOR__ == 11) && (__CUDACC_VER_MINOR__ >= 7))) && !CUTLASS_CLANG_CUDA +# define CUTLASS_GRID_CONSTANT_SUPPORTED +#endif + +// __grid_constant__ can be enabled only on SM70+ +#if defined(CUTLASS_GRID_CONSTANT_SUPPORTED) && defined(__CUDA_ARCH__) && (__CUDA_ARCH__ >= 700) +# define CUTLASS_GRID_CONSTANT_ENABLED +#endif + +#if ! defined(CUTLASS_GRID_CONSTANT) +# if defined(CUTLASS_GRID_CONSTANT_ENABLED) +# define CUTLASS_GRID_CONSTANT __grid_constant__ +# else +# define CUTLASS_GRID_CONSTANT +# endif +#endif + //////////////////////////////////////////////////////////////////////////////// namespace cutlass { +template struct Type2Type { using type=T; }; +// using the simple type to replace the complex type to reduce this symbol size +template struct GetUnderlyingKernel : public Type2Type {}; +template class Wrapper > struct GetUnderlyingKernel> : public Wrapper {}; +template using GetUnderlyingKernel_t = typename GetUnderlyingKernel::type; + + //////////////////////////////////////////////////////////////////////////////// /// Generic CUTLASS kernel template. template -__global__ +CUTLASS_GLOBAL void Kernel(typename Operator::Params params) { // Dynamic shared memory base pointer extern __shared__ int SharedStorageBase[]; - // Declare pointer to dynamic shared memory. typename Operator::SharedStorage *shared_storage = reinterpret_cast(SharedStorageBase); @@ -55,8 +81,48 @@ void Kernel(typename Operator::Params params) { Operator op; op(params, *shared_storage); + cutlass::arch::synclog_print(); } + +/// Generic CUTLASS kernel template. +template +CUTLASS_GLOBAL +void Kernel2(typename Operator::Params params) { + // Dynamic shared memory base pointer + extern __shared__ int SharedStorageBase[]; + // Declare pointer to dynamic shared memory. 
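  // (SharedStorageBase is the kernel's dynamic shared memory allocation, sized by the host at
  //  launch time; the cast below simply reinterprets that raw buffer as the Operator's typed
  //  SharedStorage view.)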
+ typename Operator::SharedStorage *shared_storage = + reinterpret_cast(SharedStorageBase); + + Operator::invoke(params, *shared_storage); + cutlass::arch::synclog_print(); + +} + + +//////////////////////////////////////////////////////////////////////////////// +// +// 3.0 specific launch +// //////////////////////////////////////////////////////////////////////////////// -} /// namespace cutlass +/// Generic CUTLASS kernel template. +template +CUTLASS_GLOBAL +#ifdef __CUDACC__ +// Enclosing this in __CUDACC__ suppresses MSVC warnings. +__launch_bounds__(Operator::MaxThreadsPerBlock, Operator::MinBlocksPerMultiprocessor) +#endif // __CUDACC__ +void device_kernel(CUTLASS_GRID_CONSTANT typename Operator::Params const params) +{ + // Dynamic shared memory base pointer + extern __shared__ char smem[]; + Operator op; + op(params, smem); + cutlass::arch::synclog_print(); + +} + +//////////////////////////////////////////////////////////////////////////////// +} /// namespace cutlass diff --git a/include/cutlass/epilogue/collective/builders/sm90_builder.inl b/include/cutlass/epilogue/collective/builders/sm90_builder.inl new file mode 100644 index 0000000000..720dcc008a --- /dev/null +++ b/include/cutlass/epilogue/collective/builders/sm90_builder.inl @@ -0,0 +1,813 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include "cute/atom/mma_traits_sm90.hpp" +#include "cute/atom/mma_traits_sm90_gmma.hpp" +#include "cute/atom/copy_traits_sm90.hpp" + +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/detail/layout.hpp" +#include "cutlass/gemm/collective/builders/sm90_common.inl" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/collective_epilogue.hpp" +#include "cutlass/epilogue/collective/builders/sm90_common.inl" +#include "cutlass/epilogue/thread/linear_combination.h" +#include "cutlass/epilogue/thread/linear_combination_generic.h" +#include "cutlass/epilogue/thread/linear_combination_bias_elementwise.h" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" + +#if defined(__CUDACC_RTC__) +#include +#else +#include +#endif + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::collective { + +/////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +// Returns the parameterized dispatch policy for the TMA epilogue +template +constexpr auto +sm90_get_tma_dispatch_policy() { + using namespace cute; + + constexpr int EpiTiles = size(shape_div(take<0,2>(TileShapeMNK{}), EpilogueTileMN{})); + constexpr int FragmentSize = size(EpilogueTileMN{}) / (detail::sm90_is_cooperative_v ? 256 : 128); + // 8b residuals load fast and consume little smem, so the perf cost of waiting on stores to finish outweighs the cost of extra allocation + constexpr bool ReuseSmem = (sizeof_bits_v == sizeof_bits_v) && (sizeof_bits_v > 8); + // TMA store delay performs worse with residual loads and compilicates tensormap updates for Ptr-Array GEMMs + constexpr bool DelayTmaStore = is_void_v && !detail::sm90_is_ptr_array_tma_v; + constexpr int StagesD = cute::min(EpiTiles, 2); + constexpr int StagesC = ReuseSmem ? cute::max(cute::min(EpiTiles, 4), StagesD+1) + : cute::min(EpiTiles, 4); + + if constexpr (detail::sm90_is_ptr_array_tma_v) { + return Sm90PtrArrayTmaWarpSpecialized{}; + } + else { + return Sm90TmaWarpSpecialized{}; + } +} + +// Returns the smem layout atom to be used for C or D matrix +template +constexpr auto +sm90_get_epilogue_smem_swizzle_layout_atom() { + using namespace cute; + + // ColMajor C/D (M-major) + if constexpr (cutlass::gemm::detail::is_major<0>(GmemStrideType{})) { + return cutlass::gemm::collective::detail::ss_smem_selector< + cute::GMMA::Major::MN, Element, decltype(get<0>(EpilogueTile_MN{})), decltype(get<1>(EpilogueTile_MN{})) + >(); + } + // RowMajor C/D (N-major) + else if constexpr (cutlass::gemm::detail::is_major<1>(GmemStrideType{})) { + return cutlass::gemm::collective::detail::ss_smem_selector< + cute::GMMA::Major::K , Element, decltype(get<0>(EpilogueTile_MN{})), decltype(get<1>(EpilogueTile_MN{})) + >(); + } + else { + static_assert(cutlass::detail::dependent_false, "Unsupported gmem layout."); + } +} + +// Attempts to compute a reasonable epilogue tile based on block tile shape or allows the user to provide one. 
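// Worked numbers connecting the dispatch-policy helper above with the tile heuristic below
// (a back-of-the-envelope sketch; the 128x128 CTA tile is an assumed example and plain ints
// stand in for the cute shape types):
constexpr int kExampleCtaM = 128, kExampleCtaN = 128;  // assumed CTA tile M,N
constexpr int kExampleEpiM = 128, kExampleEpiN = 32;   // subtile the cooperative heuristic below picks
constexpr int kExampleEpiTiles     = (kExampleCtaM / kExampleEpiM) * (kExampleCtaN / kExampleEpiN); // 4 subtiles per CTA tile
constexpr int kExampleFragmentSize = (kExampleEpiM * kExampleEpiN) / 256;          // 16 accumulators per thread on the 256-thread schedule
constexpr int kExampleStagesD      = kExampleEpiTiles < 2 ? kExampleEpiTiles : 2;  // min(EpiTiles, 2) -> 2
constexpr int kExampleStagesC      = kExampleEpiTiles < 4 ? kExampleEpiTiles : 4;  // min(EpiTiles, 4) -> 4 when smem is not reused
static_assert(kExampleEpiTiles == 4 && kExampleFragmentSize == 16 &&
              kExampleStagesD == 2 && kExampleStagesC == 4, "sanity check of the arithmetic above");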
+template +constexpr auto +sm90_compute_tile_shape_or_override() { + if constexpr (cute::is_same_v) { + auto epi_tile = [&] () { + if constexpr (detail::sm90_is_cooperative_v) { + auto tile_m = cute::min(_128{}, size<0>(TileShape_MNK{})); + auto tile_n = cute::min(_32{}, size<1>(TileShape_MNK{})); + return make_shape(tile_m, tile_n); + } + else if constexpr (detail::sm90_is_warp_specialized_v) { + constexpr int N_perf = sizeof_bits_v == 8 ? 64 : 32; + auto tile_m = cute::min(_64{}, size<0>(TileShape_MNK{})); + auto tile_n = cute::min(Int{}, size<1>(TileShape_MNK{})); + return make_shape(tile_m, tile_n); + } + else { + static_assert(cutlass::detail::dependent_false, "Unsupported schedule."); + } + }(); + + return cute::transform(epi_tile, seq<0,1>{}, + [] (auto epi_tiler, auto I) { + auto cta_tiler = make_layout(get(TileShape_MNK{})); + // This is a multimodal CTA tiler, transform before returning + if constexpr (depth(cta_tiler) > 0) { + // This is an implicit multimodal tiler, match profile and return + if constexpr (tuple_size_v == 1) { + return make_tile(epi_tiler); + } + // This is an explicit multimodal tiler, compose out epi tiler + else { + return composition(cta_tiler, epi_tiler); + } + } + // This is a flat CTA tiler, no need for transformation + else { + return epi_tiler; + } + }); + } + else if constexpr (cute::is_tuple::value) { + EpilogueTileType epi_tile; + constexpr int M = size<0>(shape(epi_tile)); + constexpr int N = size<1>(shape(epi_tile)); + + static_assert(!is_layout::value, "EpilogueTile must be a cute::Tile or cute::Shape"); + static_assert(M == 64 && detail::sm90_is_warp_specialized_v || + M == 128 && detail::sm90_is_cooperative_v, "Unsupported tile shape"); + static_assert(N % 16 == 0, "Unsupported tile shape"); + + return epi_tile; + } + else { + static_assert(cutlass::detail::dependent_false, "Invalid type for EpilogueTileType."); + } +} + +// callbacks builder with TMA aux out +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class FusionOp, + class TileShape_MNK, + class EpilogueTile_MN, + class ElementAccumulator +> +struct CallbacksBuilder< + Sm90TmaWarpSpecialized, + FusionOp, + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + cute::enable_if_t<(FusionOp::IsAuxOutSupported ^ FusionOp::IsAuxInSupported) // only one aux tensor + && not cute::is_subbyte_v> +> { + using GmemStrideTypeAux = gemm::TagToStrideC_t; + using SmemLayoutAtomAux = decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom< + GmemStrideTypeAux, typename FusionOp::ElementAux, EpilogueTile_MN>()); + using CopyOpR2S = decltype(detail::sm90_get_smem_store_op_for_accumulator< + GmemStrideTypeAux, typename FusionOp::ElementAux>()); + using CopyOpS2R = decltype(detail::sm90_get_smem_load_op_for_source< + GmemStrideTypeAux, typename FusionOp::ElementAux>()); + using SmemCopyOpAux = cute::conditional_t; + + using Callbacks = fusion::FusionCallbacks< + Sm90TmaWarpSpecialized, + FusionOp, TileShape_MNK, EpilogueTile_MN, + SmemLayoutAtomAux, SmemCopyOpAux + >; +}; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class FusionOp, + class TileShape_MNK, + class EpilogueTile_MN, + class ElementAccumulator +> +struct CallbacksBuilder< + Sm90TmaWarpSpecialized, + FusionOp, + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + cute::enable_if_t<(FusionOp::IsAuxOutSupported ^ FusionOp::IsAuxInSupported) // only one aux tensor + && sizeof_bits_v == 1> +> { + using Callbacks = 
fusion::FusionCallbacks< + Sm90TmaWarpSpecialized, + FusionOp, TileShape_MNK, EpilogueTile_MN, + Layout<_1,_0>, DefaultCopy // aux bit tensor doesn't use smem + >; +}; + +// Helper for building TMA warp-specialized collective epilogues, specialized by +// the fusion operation performed and the dispatch policy to use. +template < + class TileShape_MNK, + class EpilogueTile_MN, + class ElementAccumulator, + class ElementCompute, + class ElementC_, + class GmemLayoutTagC_, + int AlignmentC, + class ElementD_, + class GmemLayoutTagD, + int AlignmentD, + class FusionOpOrCallbacks, + class DispatchPolicy +> +struct Sm90TmaBuilderImpl { + // Passing void D disables destination store + smem allocation + using ElementD = cute::conditional_t, + fusion::get_element_aux_t, ElementD_>; + + // Passing void C disables source load + smem allocation + using ElementC = cute::conditional_t,ElementD,ElementC_>; // prevents void ref breakages + using GmemLayoutTagC = cute::conditional_t,GmemLayoutTagD,GmemLayoutTagC_>; + + using GmemStrideTypeC = cutlass::detail::TagToStrideC_t; + using GmemStrideTypeD = cutlass::detail::TagToStrideC_t; + + using UnderlyingGmemStrideTypeC = cute::remove_pointer_t; + using UnderlyingGmemStrideTypeD = cute::remove_pointer_t; + + using CopyOpS2G = cute::conditional_t, + SM90_TMA_STORE_IM2COL, + SM90_TMA_STORE + >; + using CopyOpG2S = cute::conditional_t, + SM90_TMA_LOAD_IM2COL, + SM90_TMA_LOAD + >; + + // Get the smallest tiled copy we can use to retile the accumulators + using CopyAtomC = Copy_Atom; + // Get register to register tiled copy that happen before shared memory store. + // Apply void as no register transform op needed currently. + using CopyOpR2R = void; + + // TMA builder allows for passing callbacks directly, which is either a fusion::FusionCallbacks + // instance or a direct visitor implementation, e.g. fusion::Sm90LinearCombination + using FusionCallbacks = + typename CallbacksBuilder< + DispatchPolicy, + FusionOpOrCallbacks, + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator + >::Callbacks; + + using CollectiveOp = cutlass::epilogue::collective::CollectiveEpilogue< + DispatchPolicy, + TileShape_MNK, + EpilogueTile_MN, + ElementC_, // Need to pass void through to expose via GemmUniversal + GmemStrideTypeC, + ElementD_, + GmemStrideTypeD, + FusionCallbacks, + CopyOpG2S, + decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), + decltype(detail::sm90_get_smem_load_op_for_source()), + CopyOpS2G, + decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), + decltype(detail::sm90_get_smem_store_op_for_accumulator()), + CopyAtomC, + CopyOpR2R + >; +}; + +/////////////////////////////////////////////////////////////////////////////// +// Descriptor classes for defining EVT nodes +// Some of the epilogue visitor nodes require non-intuitive template arguments +// such as CopyOpS2R for AuxLoad node. Traditionaly, these are resolved by the +// builder classes. 
Here we provide a set of descriptor classes that resolve +// these template arguments from more intuitive types such as Stride, Layout + +// Get TileShape, EpilogueTile, Dispatch Policy, StagesC, and STagesD +template< + typename TileShape_MNK, + typename EpilogueTileType, + typename ElementC, + typename ElementD, + typename Schedule +> +struct EpilogueDescriptor { + using TileShape = TileShape_MNK; + using EpilogueTile = + decltype( + detail::sm90_compute_tile_shape_or_override< + ElementD, EpilogueTileType, Schedule, TileShape_MNK + >() + ); + using DispatchPolicy = + decltype( + detail::sm90_get_tma_dispatch_policy< + TileShape_MNK, EpilogueTile, + ElementC, ElementD, Schedule + >() + ); + constexpr static int StagesC = DispatchPolicy::StagesC; + constexpr static int StagesD = DispatchPolicy::StagesD; +}; + +// Get Stride, SmemLayout, and CopyOpS2R for AuxLoad node +template< + typename EpilogueDescriptor, + typename StrideOrLayoutTag, + typename ElementAux +> +struct AuxLoadDescriptor { + constexpr static int Stages = EpilogueDescriptor::StagesC; + using EpilogueTile = typename EpilogueDescriptor::EpilogueTile; + using Element = ElementAux; + using Stride = cutlass::detail::TagToStrideC_t; + using SmemLayoutAtom = + decltype( + detail::sm90_get_epilogue_smem_swizzle_layout_atom< + Stride, ElementAux, typename EpilogueDescriptor::EpilogueTile + >() + ); + using CopyOpS2R = + decltype(detail::sm90_get_smem_load_op_for_source()); +}; + +// Get Stride, SmemLayout, and CopyOpS2R for AuxStore node +template< + typename EpilogueDescriptor, + typename StrideOrLayoutTag, + typename ElementAux +> +struct AuxStoreDescriptor { + constexpr static int Stages = EpilogueDescriptor::StagesD; + using EpilogueTile = typename EpilogueDescriptor::EpilogueTile; + using Element = ElementAux; + using Stride = cutlass::detail::TagToStrideC_t; + using SmemLayoutAtom = + decltype( + detail::sm90_get_epilogue_smem_swizzle_layout_atom< + Stride, ElementAux, typename EpilogueDescriptor::EpilogueTile + >() + ); + using CopyOpR2S = + decltype(detail::sm90_get_smem_store_op_for_accumulator()); +}; + +} // namespace detail + +/////////////////////////////////////////////////////////////////////////////// + +// No-smem builder +template < + class OpClass, + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC_, + class GmemLayoutTagC_, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class Schedule, + FloatRoundStyle RoundStyle +> +struct CollectiveBuilder< + arch::Sm90, + OpClass, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC_, + GmemLayoutTagC_, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + Schedule, + fusion::LinearCombination, + cute::enable_if_t || + cute::is_same_v || + cute::is_same_v >> { + + // Passing void C disables source load + using ElementC = cute::conditional_t, + ElementD, ElementC_>; // prevents cute breakages + using GmemLayoutTagC = cute::conditional_t, + GmemLayoutTagD, GmemLayoutTagC_>; + static constexpr thread::ScaleType::Kind ScaleType = cute::is_void_v ? 
+ thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default; + + static constexpr int FragmentSize = 1; + using ThreadOp = thread::LinearCombination< + ElementD, FragmentSize, ElementAccumulator, ElementCompute, + ScaleType, RoundStyle, ElementC>; + + using CollectiveOp = cute::conditional_t< + cute::is_same_v, + cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::DefaultEpilogue< + cutlass::detail::TagToStrideC_t, + cutlass::detail::TagToStrideC_t, + ThreadOp, + cutlass::gemm::EpilogueDefault>>, + // Epilogue for Ptr-Array and Grouped Gemm + cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::DefaultEpilogueArray< + cutlass::detail::TagToStrideC_t, + cutlass::detail::TagToStrideC_t, + ThreadOp, + Schedule>> + >; +}; + +// Tma warp-specialized builder +template < + class OpClass, + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC, + class GmemLayoutTagC, + int AlignmentC, + class ElementD_, + class GmemLayoutTagD, + int AlignmentD, + class Schedule, + class FusionOperation +> +struct CollectiveBuilder< + arch::Sm90, + OpClass, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD_, + GmemLayoutTagD, + AlignmentD, + Schedule, + FusionOperation, + cute::enable_if_t || + cute::is_same_v || + detail::sm90_is_ptr_array_tma_v>> { +private: + using ElementD = cute::conditional_t, + fusion::get_element_aux_t, ElementD_>; + using EpilogueTile_MN = + decltype(detail::sm90_compute_tile_shape_or_override()); + using DispatchPolicy = + decltype(detail::sm90_get_tma_dispatch_policy()); + +public: + using CollectiveOp = + typename detail::Sm90TmaBuilderImpl< + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD_, + GmemLayoutTagD, + AlignmentD, + FusionOperation, + DispatchPolicy + >::CollectiveOp; +}; + +// Auto builder +template < + class OpClass, + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC, + class GmemLayoutTagC, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class FusionOperation +> +struct CollectiveBuilder< + arch::Sm90, + OpClass, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + EpilogueScheduleAuto, + FusionOperation, + void> { +private: + static_assert(cute::is_same_v>, + "Auto schedule doesn't support fusion. 
Use one of the TmaWarpSpecialized schedules instead."); + + // Pick No-Smem epilogue as the Auto Epilogue Schedule (Auto schedules do not guarantee best performance) + // since TMA epilogues are not compatible with non-TMA non-WS mainloops + using EpilogueSchedule = NoSmemWarpSpecialized; + using _CollectiveBuilder = CollectiveBuilder< + arch::Sm90, + OpClass, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + EpilogueSchedule, + FusionOperation + >; + +public: + using CollectiveOp = typename _CollectiveBuilder::CollectiveOp; +}; + +// DEPRECATED Tma warp-specialized builder for elementwise fusion +template < + class OpClass, + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC, + class GmemLayoutTagC, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class Schedule, + class UnusedFusionOp +> +struct [[deprecated("Use TmaWarpSpecialized with fusion::LinCombEltAct instead")]] +CollectiveBuilder< + arch::Sm90, + OpClass, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + Schedule, + UnusedFusionOp, + cute::enable_if_t || + cute::is_base_of_v >> { +private: + using FusionOp = + fusion::LinCombEltAct; + using ImplSchedule = + cute::conditional_t, + TmaWarpSpecialized, TmaWarpSpecializedCooperative>; + +public: + using CollectiveOp = + typename CollectiveBuilder< + arch::Sm90, + OpClass, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC, + GmemLayoutTagC, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + ImplSchedule, + FusionOp + >::CollectiveOp; +}; + +// DEPRECATED Tma warp-specialized builder for bias + elementwise fusion +template < + class OpClass, + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC_, + class GmemLayoutTagC_, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class Schedule, + class UnusedFusionOp +> +struct [[deprecated("Use TmaWarpSpecialized with fusion::LinCombPerRowBiasEltAct or fusion::LinCombPerRowBiasEltActAux instead")]] +CollectiveBuilder< + arch::Sm90, + OpClass, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC_, + GmemLayoutTagC_, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + Schedule, + UnusedFusionOp, + cute::enable_if_t || + cute::is_base_of_v >> { +private: + using EpilogueTile_MN = decltype(detail::sm90_compute_tile_shape_or_override< + ElementD, EpilogueTileType, Schedule, TileShape_MNK>()); + // MSVC doesn't seem to be able to deduce DispatchPolicy correctly if it's + // defined as decltype of a detail::sm90_get_tma_dispatch_policy call. + // Instead, we paste in the contents of that function. A natural refactoring + // would be to create a type alias in the detail namespace. + using DispatchPolicy = Sm90TmaWarpSpecialized< + /* StagesC = */ size(shape_div(take<0, 2>(TileShape_MNK{}), EpilogueTile_MN{})), + /* StagesD = */ 2, + /* FragmentSize = */ size(EpilogueTile_MN{}) / (detail::sm90_is_cooperative_v ? 
256 : 128), + /* ReuseSmemC = */ sizeof_bits_v == sizeof_bits_v, + false + >; + + using GmemStrideTypeAux = gemm::TagToStrideC_t; + using SmemLayoutAtomAux = decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom< + GmemStrideTypeAux, typename Schedule::ElementT, EpilogueTile_MN>()); + using SmemCopyOpAux = decltype(detail::sm90_get_smem_store_op_for_accumulator< + GmemStrideTypeAux, typename Schedule::ElementT>()); + using FusionOperationAux = fusion::LinCombPerRowBiasEltActAux< + GmemLayoutTagD, Schedule::template ActivationFunctor, ElementD, ElementCompute, + typename Schedule::ElementT, typename Schedule::ElementBias, ElementC_, ElementCompute + >; + using FusionCallbacksAux = fusion::FusionCallbacks< + DispatchPolicy, FusionOperationAux, TileShape_MNK, EpilogueTile_MN, SmemLayoutAtomAux, SmemCopyOpAux + >; + + using FusionOperationNoAux = fusion::LinCombPerRowBiasEltAct< + Schedule::template ActivationFunctor, ElementD, ElementCompute, + typename Schedule::ElementBias, ElementC_, ElementCompute + >; + using FusionCallbacksNoAux = fusion::FusionCallbacks< + DispatchPolicy, FusionOperationNoAux, TileShape_MNK, EpilogueTile_MN + >; + + using ElementC = cute::conditional_t,ElementD,ElementC_>; // prevents void ref breakages + using GmemLayoutTagC = cute::conditional_t,GmemLayoutTagD,GmemLayoutTagC_>; + + using GmemStrideTypeC = gemm::TagToStrideC_t; + using GmemStrideTypeD = gemm::TagToStrideC_t; + + // Get the smallest tiled copy we can use to retile the accumulators + using CopyAtomC = Copy_Atom; + // Get register to register tiled copy that happen before shared memory store. + // Apply void as no register transform op needed. + using CopyOpR2R = void; + +public: + using CollectiveOp = cutlass::epilogue::collective::Sm90EpilogueTmaWarpSpecializedBiasElementwise< + DispatchPolicy::StagesC, + DispatchPolicy::StagesD, + DispatchPolicy::FragmentSize, + TileShape_MNK, + EpilogueTile_MN, + ElementC_, // Need to pass void through to expose via GemmUniversal + GmemStrideTypeC, + ElementD, + GmemStrideTypeD, + cute::conditional_t, + SM90_TMA_LOAD, + decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), + decltype(detail::sm90_get_smem_load_op_for_source()), + SM90_TMA_STORE, + decltype(detail::sm90_get_epilogue_smem_swizzle_layout_atom()), + decltype(detail::sm90_get_smem_store_op_for_accumulator()), + CopyAtomC, + CopyOpR2R + >; +}; + +// CollectiveBuilder that transposed epilogue below is used for sm90 gmma RS TT kernels +// since swapping NNN kernels input matrix and transposing its output at the same time then +// we can get TTN kernel. +template < + class OpClass, + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC_, + class GmemLayoutTagC_, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + FloatRoundStyle RoundStyle +> +struct CollectiveBuilder< + arch::Sm90, + OpClass, + TileShape_MNK, + ClusterShape_MNK, + EpilogueTileType, + ElementAccumulator, + ElementCompute, + ElementC_, + GmemLayoutTagC_, + AlignmentC, + ElementD, + GmemLayoutTagD, + AlignmentD, + cutlass::gemm::EpilogueTransposed, + fusion::LinearCombination, + void> { + // Passing void C disables source load + using ElementC = cute::conditional_t, + ElementD, ElementC_>; // prevents cute breakages + using GmemLayoutTagC = cute::conditional_t, + GmemLayoutTagD, GmemLayoutTagC_>; + static constexpr thread::ScaleType::Kind ScaleType = cute::is_void_v ? 
+ thread::ScaleType::OnlyAlphaScaling : thread::ScaleType::Default; + + static constexpr int FragmentSize = 1; + using ThreadOp = thread::LinearCombination< + ElementD, FragmentSize, ElementAccumulator, ElementCompute, + ScaleType, RoundStyle, ElementC>; + + using CollectiveOp = cutlass::epilogue::collective::detail::Sm90TmaWarpSpecializedAdapter< + cutlass::epilogue::collective::DefaultEpilogue< + cutlass::detail::TagToStrideC_t, + cutlass::detail::TagToStrideC_t, + ThreadOp, + cutlass::gemm::EpilogueTransposed> + >; +}; + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::collective diff --git a/include/cutlass/epilogue/collective/builders/sm90_common.inl b/include/cutlass/epilogue/collective/builders/sm90_common.inl new file mode 100644 index 0000000000..cd2639c5dd --- /dev/null +++ b/include/cutlass/epilogue/collective/builders/sm90_common.inl @@ -0,0 +1,80 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +/////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::collective::detail { + +/////////////////////////////////////////////////////////////////////////////// + +// Selects the largest vectorized smem store atom available +template +constexpr auto +sm90_get_smem_store_op_for_accumulator() { + using namespace cute; + + if constexpr (sizeof(ElementD) == 2 && size<0>(GmemStrideTypeD{}) == 1) { + return SM90_U16x8_STSM_T{}; + } + else if constexpr (sizeof(ElementD) == 2 && size<1>(GmemStrideTypeD{}) == 1) { + return SM90_U32x4_STSM_N{}; + } + else { + // auto-vectorizing store + return AutoVectorizingCopyWithAssumedAlignment{}; + } +} + +// Selects the largest vectorized smem load atom available +template +constexpr auto +sm90_get_smem_load_op_for_source() { + using namespace cute; + + // Reuse the logic from smem store selector + using SmemStoreOp = decltype(sm90_get_smem_store_op_for_accumulator()); + + if constexpr (cute::is_same_v) { + return SM75_U16x8_LDSM_T{}; + } + else if constexpr (cute::is_same_v) { + return SM75_U32x4_LDSM_N{}; + } + else { + // auto-vectorizing load + return AutoVectorizingCopyWithAssumedAlignment<128>{}; + } +} + +/////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::collective::detail diff --git a/include/cutlass/epilogue/collective/collective_builder.hpp b/include/cutlass/epilogue/collective/collective_builder.hpp new file mode 100644 index 0000000000..d54cd0a8f7 --- /dev/null +++ b/include/cutlass/epilogue/collective/collective_builder.hpp @@ -0,0 +1,120 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ +#pragma once + +#include // cute::DefaultCopy +#include // cute::is_base_of_v + +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/epilogue/fusion/callbacks.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Used to specify epilogue subtile shape or dispatch to automatic computation of subtile shape +struct EpilogueTileAuto {}; + +// Used to let the builder pick the epilogue schedule automatically. +// Can be overridden with kernel schedule tags in cutlass/gemm/dispatch_policy.hpp +struct EpilogueScheduleAuto {}; +struct EpilogueIm2ColScheduleAuto {}; + +template < + class ArchTag, + class OpClass, + class TileShape_MNK, + class ClusterShape_MNK, + class EpilogueTileType, + class ElementAccumulator, + class ElementCompute, + class ElementC, + class GmemLayoutTagC, + int AlignmentC, + class ElementD, + class GmemLayoutTagD, + int AlignmentD, + class EpilogueScheduleType, + class FusionOpOrCallbacks = cutlass::epilogue::fusion::LinearCombination, + class Enable = void +> +struct CollectiveBuilder { + static_assert(cutlass::detail::dependent_false, + "Could not build a collective epilogue for given parameters."); +}; + +// helper sub-builder for epilogue fusion callbacks (for internal use by CollectiveBuilder only) +namespace detail { + +// callbacks builder with operation tag +template< + class DispatchPolicy, + class FusionOp, + class TileShape_MNK, + class EpilogueTile_MN, + class ElementAccumulator, + class = void +> +struct CallbacksBuilder { + using Callbacks = fusion::FusionCallbacks; +}; + +// callbacks builder with callbacks passthrough +template < + class DispatchPolicy, + class FusionCallbacks, + class TileShape_MNK, + class EpilogueTile_MN, + class ElementAccumulator +> +struct CallbacksBuilder< + DispatchPolicy, + FusionCallbacks, + TileShape_MNK, + EpilogueTile_MN, + ElementAccumulator, + cute::enable_if_t> +> { + using Callbacks = FusionCallbacks; +}; + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "builders/sm90_builder.inl" +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/collective_epilogue.hpp b/include/cutlass/epilogue/collective/collective_epilogue.hpp new file mode 100644 index 0000000000..8fb1a9588b --- /dev/null +++ b/include/cutlass/epilogue/collective/collective_epilogue.hpp @@ -0,0 +1,71 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class DispatchPolicy, + class... Args +> +class CollectiveEpilogue { + static_assert(cutlass::detail::dependent_false, "Could not find an epilogue specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::collective + +///////////////////////////////////////////////////////////////////////////////////////////////// + +#include "detail.hpp" + +// +// Gemm +// +#include "default_epilogue.hpp" +#include "default_epilogue_array.hpp" +#include "epilogue_tensor_broadcast.hpp" +#include "sm70_epilogue_vectorized.hpp" +#include "sm70_epilogue_vectorized_array.hpp" +#include "sm90_epilogue_tma_warpspecialized.hpp" +#include "sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp" +#include "sm90_epilogue_array_tma_warpspecialized.hpp" +// +// Conv +// +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/default_epilogue.hpp b/include/cutlass/epilogue/collective/default_epilogue.hpp new file mode 100644 index 0000000000..cd4a6ccddb --- /dev/null +++ b/include/cutlass/epilogue/collective/default_epilogue.hpp @@ -0,0 +1,242 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" + +#include "cute/tensor.hpp" +#include "cute/numeric/numeric_types.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Applies an element wise operation to all elements within the fragment +/// and writes them out to destination storage. 
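// In user code, the builders introduced earlier in this change are normally reached through the
// CollectiveBuilder interface declared in collective_builder.hpp above. A minimal, illustrative
// sketch (tile sizes, element types and alignments are assumptions, not values taken from the
// library); the DefaultEpilogue defined next is one of the collectives the no-smem path can wrap:
using ExampleEpilogue = typename cutlass::epilogue::collective::CollectiveBuilder<
    cutlass::arch::Sm90, cutlass::arch::OpClassTensorOp,
    cute::Shape<cute::_128, cute::_128, cute::_64>,     // CTA tile (M, N, K)
    cute::Shape<cute::_1, cute::_1, cute::_1>,          // cluster shape
    cutlass::epilogue::collective::EpilogueTileAuto,    // let the builder pick the epilogue subtile
    float, float,                                       // ElementAccumulator, ElementCompute
    cutlass::half_t, cutlass::layout::RowMajor, 8,      // C: element, layout tag, alignment
    cutlass::half_t, cutlass::layout::RowMajor, 8,      // D: element, layout tag, alignment
    cutlass::epilogue::collective::EpilogueScheduleAuto // fusion defaults to LinearCombination
  >::CollectiveOp;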
+template < + class StrideC_, + class StrideD_, + class ThreadEpilogueOp_, + class EpilogueSchedule_ +> +class DefaultEpilogue { +public: + // + // Type Aliases + // + using EpilogueSchedule = EpilogueSchedule_; + using DispatchPolicy = EpilogueSchedule_; + + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using ElementD = typename ThreadEpilogueOp::ElementD; + using StrideD = StrideD_; + + using GmemTiledCopyC = void; + using GmemTiledCopyD = void; + + static const int kOutputAlignment = ThreadEpilogueOp::kCount; + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + + static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + struct SharedStorage { }; + + using TensorStorage = SharedStorage; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; + ElementC const* ptr_C = nullptr; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + }; + + // Device side epilogue params + using Params = Arguments; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& _, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + // Note: SharedStorage is unused for DefaultEpilogue + CUTLASS_HOST_DEVICE + DefaultEpilogue(Params const& params_, SharedStorage const& shared_storage = SharedStorage()) + : params(params_), epilogue_op(params_.thread) { } + + CUTLASS_DEVICE + bool + is_source_needed() { + return epilogue_op.is_source_needed(); + } + + template< + class ProblemShapeMNKL, + class BlockShapeMNK, + class BlockCoordMNKL, + class FrgEngine, class FrgLayout, + class TiledMma, + class ResidueMNK + > + CUTLASS_HOST_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, + cute::Tensor const& accumulators, + TiledMma tiled_mma, + ResidueMNK residue_mnk, + int thread_idx, + [[maybe_unused]] char* smem_buf) + { + using namespace cute; + using X = Underscore; + + static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "ThreadBlock tile shape must be static"); + static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = 
get<3>(problem_shape_mnkl); + + auto stride_c = detail::get_epilogue_stride(params.dC); + auto stride_d = detail::get_epilogue_stride(params.dD); + + // Represent the full output tensor + Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), stride_c); // (m,n,l) + Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d); // (m,n,l) + Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + + // Partition source and destination tiles to match the accumulator partitioning + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) + Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) + + static_assert(is_static::value, "Accumulator layout must be static"); + CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), + "Source and destination must have the same number of elements."); + CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), + "Accumulator count must have the same destination element count."); + + // Make an identity coordinate tensor for predicating our output MN tile + auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); + Tensor tCcD = thr_mma.partition_C(cD); + + // source is needed + if (epilogue_op.is_source_needed()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + tCgD(i) = epilogue_op(accumulators(i), tCgC(i)); + } + } + } + // source is not needed, avoid load + else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + tCgD(i) = epilogue_op(accumulators(i)); + } + } + } + } + +private: + Params params; + ThreadEpilogueOp epilogue_op; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/default_epilogue_array.hpp b/include/cutlass/epilogue/collective/default_epilogue_array.hpp new file mode 100644 index 0000000000..da7562b43a --- /dev/null +++ b/include/cutlass/epilogue/collective/default_epilogue_array.hpp @@ -0,0 +1,283 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" + +#include "cute/tensor.hpp" +#include "cute/numeric/numeric_types.hpp" +#include "cutlass/trace.h" + +#include "cutlass/cuda_host_adapter.hpp" +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Applies an element wise operation to all elements within the fragment +// and writes them out to destination storage. 
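// Usage note (a minimal, illustrative sketch -- the buffer names below are
// placeholders chosen for this example, not symbols defined in this header):
// the pointer-array epilogue below takes one C pointer and one D pointer per
// batch or group, and for grouped GEMM StrideC/StrideD are themselves pointer
// types, so strides are also supplied per group. A host would assemble
// device-visible arrays and pass them through Arguments roughly as follows:
//
//   ElementC const** d_ptr_C;   // device array: d_ptr_C[g] -> C of group g
//   ElementD**       d_ptr_D;   // device array: d_ptr_D[g] -> D of group g
//   StrideC          d_dC;      // per-group strides (grouped GEMM) or a single stride
//   StrideD          d_dD;
//
//   typename CollectiveEpilogue::Arguments epilogue_args{
//     {alpha, beta},            // thread-level epilogue parameters
//     d_ptr_C, d_dC,
//     d_ptr_D, d_dD
//   };
//
// The {alpha, beta} initializer is schematic; the exact Params construction
// depends on the chosen ThreadEpilogueOp.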
+template < + class StrideC_, + class StrideD_, + class ThreadEpilogueOp_, + class EpilogueSchedule_ +> +class DefaultEpilogueArray { +public: + // + // Type Aliases + // + using EpilogueSchedule = EpilogueSchedule_; + using DispatchPolicy = EpilogueSchedule_; + + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using InternalStrideC = cute::remove_pointer_t; + using ElementD = typename ThreadEpilogueOp::ElementD; + using StrideD = StrideD_; + using InternalStrideD = cute::remove_pointer_t; + + using GmemTiledCopyC = void; + using GmemTiledCopyD = void; + + static const int kOutputAlignment = ThreadEpilogueOp::kCount; + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + + static_assert(cute::is_same_v || cute::is_same_v || cute::is_same_v, "Incompatible epilogue schedule."); + static_assert(rank(InternalStrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(rank(InternalStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + struct SharedStorage { }; + + using TensorMapStorage = SharedStorage; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; + ElementC const** ptr_C = nullptr; + StrideC dC{}; + ElementD** ptr_D = nullptr; + StrideD dD{}; + }; + + // Device side epilogue params + using Params = Arguments; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const&, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + CUTLASS_HOST_DEVICE + DefaultEpilogueArray(Params const& params_) + : params(params_) { } + + CUTLASS_DEVICE + bool + is_source_needed() { + // For Ptr-Array or Grouped Gemm we cannot determine if source is needed based on first beta. 
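    // (With pointer arrays, beta can differ between batches/groups, so the real
    //  decision is deferred to the per-group ThreadEpilogueOp that operator()
    //  constructs with the group index; here we conservatively report true.)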
+ return true; + } + + template< + class ProblemShapeMNKL, + class BlockShapeMNK, + class BlockCoordMNKL, + class FrgEngine, class FrgLayout, + class TiledMma, + class ResidueMNK + > + CUTLASS_HOST_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, + cute::Tensor const& accumulators, + TiledMma tiled_mma, + ResidueMNK residue_mnk, + int thread_idx, + [[maybe_unused]] char* smem_buf) + { + using namespace cute; + using X = Underscore; + + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "ThreadBlock tile shape must be static"); + static_assert(rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + // Batches are managed by using appropriate pointers to C and D matrices + const int32_t mock_L = 1; + const int32_t mock_l_coord = 0; + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + + // If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups. + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups, + // we get the correct alpha/beta values for the current batch/group using group index. + ThreadEpilogueOp epilogue_op = ThreadEpilogueOp(params.thread, l_coord); + + if (epilogue_op.is_source_needed() && params.dC == nullptr) { + // Beta value is non-zero while pointer to C is a nullptr + assert(0); + } + + auto [stride_c, stride_d] = [&, l = l_coord]() { + if constexpr (!cute::is_same_v) { + // If grouped gemm + if (epilogue_op.is_source_needed()) { + return make_tuple( + detail::get_epilogue_stride(params.dC[l]), + detail::get_epilogue_stride(params.dD[l]) + ); + } + else { + return make_tuple( + InternalStrideC{}, + detail::get_epilogue_stride(params.dD[l]) + ); + } + } + else { + return make_tuple( + detail::get_epilogue_stride(params.dC), + detail::get_epilogue_stride(params.dD) + ); + } + }(); + + // Represent the full output tensor + ElementC const* ptr_C_l = nullptr; + if (epilogue_op.is_source_needed()) { + ptr_C_l = params.ptr_C[l_coord]; + } + Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C_l), make_shape(M,N,mock_L), stride_c); // (m,n,l) + Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), make_shape(M,N,mock_L), stride_d); // (m,n,l) + Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + + Tensor gC = gC_mnl(_,_,m_coord,n_coord, mock_l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_,_,m_coord,n_coord, mock_l_coord); // (BLK_M,BLK_N) + + // Partition source and destination tiles to match the accumulator partitioning + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) + Tensor tCgC = thr_mma.partition_C(gC); // (VEC,THR_M,THR_N) + + static_assert(is_static::value, "Accumulator layout must be static"); + CUTE_STATIC_ASSERT_V(size(tCgC) == size(tCgD), + "Source and destination must have the same number of elements."); + CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), + "Accumulator count must have the 
same destination element count."); + + // Make an identity coordinate tensor for predicating our output MN tile + auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); + Tensor tCcD = thr_mma.partition_C(cD); + + // source is needed + if (epilogue_op.is_source_needed()) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + tCgD(i) = epilogue_op(accumulators(i), tCgC(i)); + } + } + } + // source is not needed, avoid load + else { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + tCgD(i) = epilogue_op(accumulators(i)); + } + } + } + } + +private: + Params params; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/detail.hpp b/include/cutlass/epilogue/collective/detail.hpp new file mode 100644 index 0000000000..23e57d99b8 --- /dev/null +++ b/include/cutlass/epilogue/collective/detail.hpp @@ -0,0 +1,502 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/pipeline/pipeline.hpp" +#include "cutlass/gemm/gemm.h" +#include "cutlass/gemm/dispatch_policy.hpp" +#include "cutlass/epilogue/dispatch_policy.hpp" + +#include "cute/tensor.hpp" +#include "cute/numeric/numeric_types.hpp" +#include "cute/util/type_traits.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +namespace detail { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +constexpr bool +is_m_major() { + return cutlass::gemm::detail::is_major<0,Stride>(); +} + +template +constexpr bool +is_n_major() { + return cutlass::gemm::detail::is_major<1,Stride>(); +} + +template +constexpr bool +is_im2col() { + return cute::is_same_v> + || cute::is_same_v> + || cute::is_same_v>; +} + +template +struct sm90_is_ptr_array_tma : cute::false_type {}; + +template<> +struct sm90_is_ptr_array_tma : cute::true_type {}; + +template<> +struct sm90_is_ptr_array_tma : cute::true_type {}; + +template<> +struct sm90_is_ptr_array_tma : cute::true_type {}; + +template +static constexpr bool sm90_is_ptr_array_tma_v = sm90_is_ptr_array_tma::value; + +template +struct sm90_is_ptr_array_tma_cooperative : cute::false_type {}; + +template<> +struct sm90_is_ptr_array_tma_cooperative : cute::true_type {}; + +template +static constexpr bool sm90_is_ptr_array_tma_cooperative_v = sm90_is_ptr_array_tma_cooperative::value; + +template +struct sm90_is_ptr_array_tma_pingpong : cute::false_type {}; + +template<> +struct sm90_is_ptr_array_tma_pingpong : cute::true_type {}; + +template +static constexpr bool sm90_is_ptr_array_tma_pingpong_v = sm90_is_ptr_array_tma_pingpong::value; + +template +struct sm90_is_ptr_array_tma_dispatch_policy : cute::false_type {}; + +template< + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups +> +struct sm90_is_ptr_array_tma_dispatch_policy< + Sm90PtrArrayTmaWarpSpecialized> + : cute::true_type {}; + +template +static constexpr bool sm90_is_ptr_array_tma_dispatch_policy_v = sm90_is_ptr_array_tma_dispatch_policy::value; + +using cutlass::atomic_maximum; + +template +static constexpr int elements_per_access_v = cutlass::sizeof_bits::value / cutlass::sizeof_bits::value; + +template +static constexpr bool sm90_is_cooperative_v = + cute::is_base_of_v || + sm90_is_ptr_array_tma_cooperative_v; + +template +static constexpr bool sm90_is_warp_specialized_v = + (!sm90_is_ptr_array_tma_cooperative_v && sm90_is_ptr_array_tma_v) || + cute::is_base_of_v; + +template +static constexpr bool is_im2col_mode = + cute::is_same_v || + cute::is_same_v || + cute::is_same_v; + +template +struct EmptyStorage { + CUTLASS_HOST_DEVICE + T* data() { return nullptr; } +}; + +template +CUTLASS_HOST_DEVICE +auto get_epilogue_stride(Stride stride){ + if constexpr (cute::is_base_of_v|| + cute::is_base_of_v) { + return cute::make_stride(cute::get<1>(stride), cute::get<0>(stride), cute::get<2>(stride)); + } + else { + return stride; + } +} + +template +struct IsThreadEpilogueOpWithBias { + static constexpr bool value = false; + using type = typename ThreadEpilogueOp::ElementCompute; +}; + +template +struct IsThreadEpilogueOpWithBias > { + static constexpr bool value = true; + using type = typename 
ThreadEpilogueOp::ElementBias; +}; + +template +struct IsThreadEpilogueOpWithPerChannelScaling { + static constexpr bool value = false; +}; + +template +struct IsThreadEpilogueOpWithPerChannelScaling > { + static constexpr bool value = true; +}; + +template +struct IsThreadEpilogueOpWithActivation { + static constexpr bool value = false; + using type = void; +}; + +template +struct IsThreadEpilogueOpWithActivation > { + static constexpr bool value = true; + using type = typename ThreadEpilogueOp::ActivationFn; +}; + +template +struct IsThreadEpilogueOpWithElementwiseArguments : cute::false_type {}; + +template +struct IsThreadEpilogueOpWithElementwiseArguments< + ThreadEpilogueOp, + cute::void_t> : cute::true_type {}; + +// Wrapper class to use operator-style epilogues in sm90 TMA warp-specialized kernels +template +class Sm90TmaWarpSpecializedAdapter : public EpilogueOp { +public: + using GmemTiledCopyC = void; + using GmemTiledCopyD = void; + + using LoadPipeline = cutlass::PipelineTransactionAsync<0>; + using LoadPipelineState = cutlass::PipelineState<0>; + constexpr static uint32_t TmaTransactionBytes = 0; + constexpr static bool RequiresTransactionBytes = false; + + using StorePipeline = cutlass::PipelineTmaStore<0>; + using StorePipelineState = cutlass::PipelineState<0>; + + using TensorStorage = typename EpilogueOp::SharedStorage; + using TensorMapStorage = typename EpilogueOp::SharedStorage; + using PipelineStorage = typename LoadPipeline::SharedStorage; + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_load_pipe_increment(CtaTileMNK) { + return 1; + } + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_store_pipe_increment(CtaTileMNK) { + return 1; + } + + CUTLASS_DEVICE + static void prefetch_tma_descriptors([[maybe_unused]] typename EpilogueOp::Params const&) { + } + + // ctor inheritance + using EpilogueOp::EpilogueOp; + + CUTLASS_HOST_DEVICE + Sm90TmaWarpSpecializedAdapter( + typename EpilogueOp::Params const& params, + [[maybe_unused]] TensorStorage& shared_tensors) + : EpilogueOp(params) { } + + CUTLASS_DEVICE + bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE auto + load_init( + [[maybe_unused]] typename EpilogueOp::Params const& params, + [[maybe_unused]] TensorMapStorage& shared_tensormaps, + [[maybe_unused]] int32_t sm_count, + [[maybe_unused]] int32_t sm_idx) { + return cute::make_tuple(nullptr); + } + + template< + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class TiledMma + > + CUTLASS_DEVICE auto + load( + [[maybe_unused]] LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + [[maybe_unused]] ProblemShapeMNKL problem_shape_mnkl, + [[maybe_unused]] CtaTileMNK cta_tile_mnk, + [[maybe_unused]] CtaCoordMNKL cta_coord_mnkl, + [[maybe_unused]] TiledMma tiled_mma, + [[maybe_unused]] int thread_idx, + [[maybe_unused]] TensorStorage& shared_tensors, + [[maybe_unused]] int subtile_idx=-1) + { + return load_pipe_producer_state; + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class TiledMma, + class TensorMapC + > + CUTLASS_DEVICE auto + load( + [[maybe_unused]] LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + [[maybe_unused]] ProblemShapeMNKL problem_shape_mnkl, + [[maybe_unused]] TileShapeMNK tile_shape_MNK, + [[maybe_unused]] TileCoordMNKL tile_coord_mnkl, + [[maybe_unused]] TiledMma tiled_mma, + [[maybe_unused]] int thread_idx, + [[maybe_unused]] TensorStorage& shared_tensors, + [[maybe_unused]] 
TensorMapC const& load_tensormap, + [[maybe_unused]] int subtile_idx=-1, + [[maybe_unused]] bool wait = false) + { + return load_pipe_producer_state; + } + + CUTLASS_DEVICE auto + load_tail( + [[maybe_unused]] LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state) + { + return load_pipe_producer_state; + } + + CUTLASS_DEVICE auto + store_init( + [[maybe_unused]] typename EpilogueOp::Params const& params, + [[maybe_unused]] TensorMapStorage& shared_tensormaps, + [[maybe_unused]] int32_t sm_count, + [[maybe_unused]] int32_t sm_idx, + [[maybe_unused]] int32_t warp_group_idx) { + return cute::make_tuple(nullptr); + } + + template< + class ProblemShapeMNKL, + class CtaTileMNK, + class CtaCoordMNKL, + class AccEngine, class AccLayout, + class TiledMma + > + CUTLASS_DEVICE auto + store( + [[maybe_unused]] LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + [[maybe_unused]] StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + CtaTileMNK cta_tile_mnk, + CtaCoordMNKL cta_coord_mnkl, + cute::Tensor accumulators, + TiledMma tiled_mma, + int thread_idx, + TensorStorage& shared_tensors, + int subtile_index = -1) + { + constexpr int BLK_M_RANK = cute::rank<0>(cta_tile_mnk); + auto m_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { + return get<0,i>(problem_shape_mnkl) - get<0,i>(cta_tile_mnk) * get<0,i>(cta_coord_mnkl); + })); + + constexpr int BLK_N_RANK = cute::rank<1>(cta_tile_mnk); + auto n_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { + return get<1,i>(problem_shape_mnkl) - get<1,i>(cta_tile_mnk) * get<1,i>(cta_coord_mnkl); + })); + + auto residue_mnk = make_tuple(m_max_coord, n_max_coord, Int<0>{}); + + (*this)( + problem_shape_mnkl, + cta_tile_mnk, + cta_coord_mnkl, + accumulators, + tiled_mma, + residue_mnk, + thread_idx, + reinterpret_cast(&shared_tensors)); + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class AccEngine, class AccLayout, + class TiledMma, + class TensorMapD + > + CUTLASS_DEVICE auto + store( + [[maybe_unused]] LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + [[maybe_unused]] StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + cute::Tensor accumulators, + TiledMma tiled_mma, + int thread_idx, + TensorStorage& shared_tensors, + [[maybe_unused]] TensorMapD const& store_tensormap, + int subtile_index = -1) + { + constexpr int BLK_M_RANK = cute::rank<0>(tile_shape_MNK); + auto m_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { + return get<0,i>(problem_shape_mnkl) - get<0,i>(tile_shape_MNK) * get<0,i>(tile_coord_mnkl); + })); + + constexpr int BLK_N_RANK = cute::rank<1>(tile_shape_MNK); + auto n_max_coord = unwrap(cute::transform(make_seq{}, [&](auto i) { + return get<1,i>(problem_shape_mnkl) - get<1,i>(tile_shape_MNK) * get<1,i>(tile_coord_mnkl); + })); + + auto residue_mnk = make_tuple(m_max_coord, n_max_coord, Int<0>{}); + + (*this)( + problem_shape_mnkl, + tile_shape_MNK, + tile_coord_mnkl, + accumulators, + tiled_mma, + residue_mnk, + thread_idx, + reinterpret_cast(&shared_tensors)); + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + + CUTLASS_DEVICE auto + store_tail( + [[maybe_unused]] LoadPipeline load_pipeline, + 
LoadPipelineState load_pipe_consumer_state, + [[maybe_unused]] StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state) { + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + + // Dummy methods to perform different parts of TMA/Tensormap modifications + + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + [[maybe_unused]] TensorMapStorage& shared_tensormaps, + [[maybe_unused]] typename EpilogueOp::Params const& params, + [[maybe_unused]] cute::TmaDescriptor const* tensormap, + [[maybe_unused]] ProblemShapeMNKL problem_shape, + [[maybe_unused]] int32_t next_batch, + [[maybe_unused]] int32_t warp_group_idx) { } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release( + [[maybe_unused]] TensorMapStorage& shared_tensormaps, + [[maybe_unused]] cute::TmaDescriptor const* tensormap, + [[maybe_unused]] int32_t warp_group_idx) { } + + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire([[maybe_unused]] cute::TmaDescriptor const* tensormap) { } +}; + +// SFINAE helpers for detecting beta/beta_ptr/beta_ptr_array in EVT arguments. +template +struct has_beta { + static constexpr bool value = false; +}; + +template +struct has_beta> { + static constexpr bool value = true; +}; + +template +struct has_beta_ptr { + static constexpr bool value = false; +}; + +template +struct has_beta_ptr> { + static constexpr bool value = true; +}; + +template +struct has_beta_ptr_array { + static constexpr bool value = false; +}; + +template +struct has_beta_ptr_array> { + static constexpr bool value = true; +}; + +} // namespace detail +} // namespace collective +} // namespace epilogue +} // namespace cutlass diff --git a/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp b/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp new file mode 100644 index 0000000000..48833ecf10 --- /dev/null +++ b/include/cutlass/epilogue/collective/epilogue_tensor_broadcast.hpp @@ -0,0 +1,271 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Functor for performing tensor-tensor broadacasts atop existing epilogues. + + Concretely, the opeartion performed is the following: + UnaryOp( + BinaryOp1( + BinaryOp0( + Activation((alpha * A @ B) + bias), + beta * C0 + ), + beta * C1 + ) + ) + + where: + - C0 and C1 have the same extents as the output + - BinaryOp0 and BinaryOp1 perform elementwise binary operations + - UnaryOp is an elementwise operation +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/epilogue/collective/detail.hpp" + +#include "cute/tensor.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { +///////////////////////////////////////////////////////////////////////////////////////////////// + +/// Collective epilogue that applies elementwise tensor-tensor operations atop other epilogues +/// +template < + class StrideC_, + class StrideD_, + class ThreadEpilogueOp_, + class EpilogueSchedule_, + bool PerColumnBias_ = false +> +class EpilogueTensorBroadcast { +public: + // + // Type Aliases + // + using EpilogueSchedule = EpilogueSchedule_; + + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementBias = typename ThreadEpilogueOp::ElementBias; + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using ElementD = typename ThreadEpilogueOp::ElementD; + using StrideD = StrideD_; + using ActivationFunctor = typename ThreadEpilogueOp::ActivationFunctor; + + static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + static constexpr int kOutputAlignment = ThreadEpilogueOp::kCount; + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + + static constexpr bool IsBinaryOp0Enabled = ThreadEpilogueOp::IsBinaryOp0Enabled; + static constexpr bool IsBinaryOp1Enabled = ThreadEpilogueOp::IsBinaryOp1Enabled; + static constexpr bool IsUnaryOpEnabled = ThreadEpilogueOp::IsUnaryOpEnabled; + + static constexpr bool PerColumnBias = PerColumnBias_; + using BiasStride = typename cute::conditional_t, Stride<_1, _0, _0>>; + + struct SharedStorage { }; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + ElementBias* ptr_Bias = nullptr; + ElementC* ptr_C0 = nullptr; + ElementC* ptr_C1 = nullptr; + }; + + // Device side epilogue params 
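  // (Params are the host Arguments verbatim; no device-side repacking is needed.
  //  The per-element computation follows the formula in the file header, e.g. with
  //  BinaryOp0 = BinaryOp1 = plus and UnaryOp = identity:
  //    D = Activation(alpha * acc + bias) + beta * C0 + beta * C1.)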
+ using Params = Arguments; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& _, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + CUTLASS_HOST_DEVICE + EpilogueTensorBroadcast(Params const& params_) + : params(params_), epilogue_op(params_.thread) { } + + CUTLASS_DEVICE + bool + is_source_needed() { + return epilogue_op.is_source0_needed() || epilogue_op.is_source1_needed(); + } + + template< + class ProblemShapeMNKL, + class BlockShapeMNK, + class BlockCoordMNKL, + class FrgEngine, class FrgLayout, + class TiledMma, + class ResidueMNK + > + CUTLASS_HOST_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, + cute::Tensor const& accumulators, + TiledMma tiled_mma, + ResidueMNK residue_mnk, + int thread_idx, + [[maybe_unused]] char* smem_buf) + { + using namespace cute; + using X = Underscore; + + static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "ThreadBlock tile shape must be static"); + static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 4"); + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + auto stride_c = detail::get_epilogue_stride(params.dC); + auto stride_d = detail::get_epilogue_stride(params.dD); + auto stride_bias = detail::get_epilogue_stride(BiasStride{}); + + // Represent the full output tensor + Tensor mC0_mnl = make_tensor(make_gmem_ptr(params.ptr_C0), make_shape(M,N,L), stride_c); // (m,n,l) + Tensor mC1_mnl = make_tensor(make_gmem_ptr(params.ptr_C1), make_shape(M,N,L), stride_c); // (m,n,l) + Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), stride_d); // (m,n,l) + Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_Bias), make_shape(M,N,L), stride_bias); // (m,n,l) + + Tensor gC0_mnl = local_tile(mC0_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gC1_mnl = local_tile(mC1_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gBias_mnl = local_tile(mBias_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + + // Slice to get the tile this thread block is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + Tensor gC0 = gC0_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gC1 = gC1_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gBias = gBias_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + + // Partition 
source and destination tiles to match the accumulator partitioning + auto thr_mma = tiled_mma.get_thread_slice(thread_idx); + Tensor tCgD = thr_mma.partition_C(gD); // (VEC,THR_M,THR_N) + Tensor tCgC0 = thr_mma.partition_C(gC0); // (VEC,THR_M,THR_N) + Tensor tCgC1 = thr_mma.partition_C(gC1); // (VEC,THR_M,THR_N) + Tensor tCgBias = thr_mma.partition_C(gBias); // (VEC,THR_M,THR_N) + + static_assert(is_static::value, + "Accumulator layout must be static"); + CUTE_STATIC_ASSERT_V(size(tCgC0) == size(tCgD), + "Source and destination must have the same number of elements."); + CUTE_STATIC_ASSERT_V(size(tCgC1) == size(tCgD), + "Source and destination must have the same number of elements."); + CUTE_STATIC_ASSERT_V(size(tCgD) == size(accumulators), + "Accumulator count must have the same destination element count."); + CUTE_STATIC_ASSERT_V(size(tCgBias) == size(accumulators), + "Accumulator count must have the same destination element count."); + + auto cD = make_identity_tensor(make_shape(unwrap(shape<0>(gD)), unwrap(shape<1>(gD)))); + Tensor tCcD = thr_mma.partition_C(cD); + + bool bias_needed = params.ptr_Bias != nullptr; + bool c0_needed = (params.ptr_C0 != nullptr) && epilogue_op.is_source0_needed(); + bool c1_needed = (params.ptr_C1 != nullptr) && epilogue_op.is_source1_needed(); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(accumulators); ++i) { + if (elem_less(tCcD(i), make_coord(get<0>(residue_mnk), get<1>(residue_mnk)))) { + ElementBias bias = bias_needed ? tCgBias(i) : ElementBias(0); + ElementC c0 = c0_needed ? tCgC0(i) : ElementC(0); + ElementC c1 = c1_needed ? tCgC1(i) : ElementC(0); + + tCgD(i) = epilogue_op(accumulators(i), c0, c1, bias); + } + } + } + +private: + Params params; + ThreadEpilogueOp epilogue_op; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp b/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp new file mode 100644 index 0000000000..a8083dab1d --- /dev/null +++ b/include/cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp @@ -0,0 +1,549 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + class StrideC, + class StrideD, + class ThreadEpilogueOp, + class SmemLayout, + class CopyAtomR2S, + class TiledCopyS2R, + class CopyAtomR2G, + class EpilogueScheduleType = EpilogueSimtVectorized, + class Enable = void +> +class Epilogue { + static_assert(cute::is_same_v || + cute::is_same_v, + "Could not find an epilogue specialization."); +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Epilogue Vectorized +/// Applies an element wise operation to all elements within the fragment +/// and writes it out to destination storage. +/// +/// Ways to generalize this: +/// - CTA tile shape +/// - vectorization requirements (GMEM) +/// - vectoriz(able) transform() +/// +template < + class StrideC_, + class StrideD_, + class ThreadEpilogueOp_, + class SmemLayout_, + class CopyAtomR2S_, + class TiledCopyS2R_, + class CopyAtomR2G_, + class EpilogueScheduleType_ +> +class Epilogue< + StrideC_, + StrideD_, + ThreadEpilogueOp_, + SmemLayout_, + CopyAtomR2S_, + TiledCopyS2R_, + CopyAtomR2G_, + EpilogueScheduleType_, + cute::enable_if_t< + cute::is_same_v + > + > { +public: + // + // Type Aliases + // + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using ElementD = typename ThreadEpilogueOp::ElementD; + using StrideD = StrideD_; + using ElementBias = typename detail::IsThreadEpilogueOpWithBias::type; + using SmemLayout = SmemLayout_; + using CopyAtomR2S = CopyAtomR2S_; + using TiledCopyS2R = TiledCopyS2R_; + using CopyAtomR2G = CopyAtomR2G_; + + using GmemTiledCopyC = void; + using GmemTiledCopyD = CopyAtomR2G; + + static constexpr bool IsEpilogueBiasSupported = detail::IsThreadEpilogueOpWithBias::value; + using StrideBias = cute::conditional_t(), Stride<_1,_0,int64_t>, Stride<_0,_1,int64_t>>; + + static_assert(cute::rank(StrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + struct SharedStorage + { + cute::array_aligned> smem_epilogue; + }; + + static constexpr bool IsActHasArgs = 
detail::IsThreadEpilogueOpWithElementwiseArguments::value; + + // Host side epilogue arguments + template + struct ThreadEpilogueOpArguments { + ElementScalar alpha{0}; + ElementScalar beta{0}; + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias{}; + }; + + template + struct ThreadEpilogueOpArguments< + ThreadEpiOp, + cute::enable_if_t::value>> { + ElementScalar alpha{0}; + ElementScalar beta{0}; + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias{}; + typename ThreadEpiOp::ElementwiseArguments activation{}; + }; + + struct Arguments { + ThreadEpilogueOpArguments thread{}; + using StrideBias = decltype(thread.dBias); + ElementC const* ptr_C = nullptr; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + }; + + // Device side epilogue params + template + struct ParamsType { + typename ThreadEpiOp::Params thread{}; + ElementC const* ptr_C = nullptr; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + ElementBias const* ptr_Bias = nullptr; + StrideBias dBias{}; + }; + + template + struct ParamsType< + ThreadEpiOp, + cute::enable_if_t::value>> { + typename ThreadEpiOp::Params thread{}; + typename ThreadEpiOp::ElementwiseArguments activation{}; + ElementC const* ptr_C = nullptr; + StrideC dC{}; + ElementD* ptr_D = nullptr; + StrideD dD{}; + ElementBias const* ptr_Bias = nullptr; + StrideBias dBias{}; + }; + + using Params = ParamsType; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + [[maybe_unused]] ProblemShape const& _, + Arguments const& args, + [[maybe_unused]] void* workspace) { + typename ThreadEpilogueOp::Params thread_op_args; + thread_op_args.alpha = args.thread.alpha; + thread_op_args.beta = args.thread.beta; + thread_op_args.alpha_ptr = args.thread.alpha_ptr; + thread_op_args.beta_ptr = args.thread.beta_ptr; + + if constexpr (IsActHasArgs) { + return { + thread_op_args, + args.thread.activation, + args.ptr_C, + args.dC, + args.ptr_D, + args.dD, + args.thread.bias_ptr, + args.thread.dBias + }; + } + else { + return { + thread_op_args, + args.ptr_C, + args.dC, + args.ptr_D, + args.dD, + args.thread.bias_ptr, + args.thread.dBias + }; + } + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + CUTLASS_HOST_DEVICE + Epilogue(Params const& params_) + : params(params_), epilogue_op(params_.thread) { } + + CUTLASS_DEVICE + bool + is_source_needed() { + return epilogue_op.is_source_needed(); + } + + template< + class ProblemShapeMNKL, + class BlockShapeMNK, + class BlockCoordMNKL, + class FrgEngine, class FrgLayout, + class TiledMma, + class ResidueMNK + > + CUTLASS_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, + cute::Tensor const& accumulators, // (MMA,MMA_M,MMA_N) + TiledMma tiled_mma, + ResidueMNK residue_mnk, + int thread_idx, + char* smem_buf) { + using namespace cute; + using X = 
Underscore; + + static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "ThreadBlock tile shape must be static"); + static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); + + // synchronizing function for smem reads/writes +#if CUDA_BARRIER_ENABLED + auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(typename TiledCopyS2R::TiledNumThr{}, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; +#else + auto synchronize = [] () { __syncthreads(); }; +#endif + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + + // Represent the full output tensor + Tensor mC_mnl = make_tensor(make_gmem_ptr(params.ptr_C), make_shape(M,N,L), params.dC); // (m,n,l) + Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D), make_shape(M,N,L), params.dD); // (m,n,l) + Tensor mBias_mnl = make_tensor(make_gmem_ptr(params.ptr_Bias), make_shape(M,N,L), params.dBias); // (m,n,l) + + Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gBias_mnl = local_tile(mBias_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + Tensor gC = gC_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + Tensor gBias = gBias_mnl(_,_,m_coord,n_coord,l_coord); // (BLK_M,BLK_N) + + // Construct a tensor in SMEM that we can partition for rearranging data + SharedStorage& storage = *reinterpret_cast(smem_buf); + Tensor sAcc = make_tensor(make_smem_ptr(storage.smem_epilogue.data()), SmemLayout{}); // (SMEM_M,SMEM_N) + + // Partition sAcc to match the accumulator partitioning + auto tiled_r2s = make_tiled_copy_C(CopyAtomR2S{}, tiled_mma); + auto thread_r2s = tiled_r2s.get_thread_slice(thread_idx); + Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor tRS_sAcc = thread_r2s.partition_D(sAcc); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // Tile gD and gC by the shape of SmemLayout first + auto tile = make_shape(size<0>(sAcc), size<1>(sAcc)); + Tensor gCt = flat_divide(gC, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + Tensor gDt = flat_divide(gD, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + Tensor gBiast = flat_divide(gBias, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + + // Partition sAcc, gC, and gD for the output + auto tiled_s2r = TiledCopyS2R{}; + auto thread_s2r = tiled_s2r.get_thread_slice(thread_idx); + Tensor tSR_sAcc = thread_s2r.partition_S(sAcc); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tSR_gC = thread_s2r.partition_D(gCt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + Tensor tSR_gD = thread_s2r.partition_D(gDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + Tensor tSR_gBias = thread_s2r.partition_D(gBiast); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + + // Allocate intermediate registers on the dst tensors + Tensor tSR_rAcc = make_tensor(take<0,3>(shape(tSR_gC))); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tSR_rC = make_tensor(shape(tSR_rAcc)); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tSR_rD = 
make_tensor(shape(tSR_rAcc)); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tSR_rBias = make_tensor_like(tSR_gBias); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + + // Repeat the D-partitioning for coordinates and predication + Tensor cD = make_identity_tensor(make_shape(size<0>(gD),size<1>(gD))); // (BLK_M,BLK_N) -> (blk_m,blk_n) + Tensor cDt = flat_divide(cD, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + Tensor tSR_cD = thread_s2r.partition_D(cDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + + CUTE_STATIC_ASSERT(size<1>(tRS_rAcc) % size<3>(tSR_gC) == 0); // TILE_M divides MMA_M + CUTE_STATIC_ASSERT(size<2>(tRS_rAcc) % size<4>(tSR_gC) == 0); // TILE_N divides MMA_N + +#if 0 + if (thread_idx == 0 && m_coord == 0 && n_coord == 0) { + print("aC : "); print(accumulators.layout()); print("\n"); + print("gC : "); print(gC.layout()); print("\n"); + print("gD : "); print(gD.layout()); print("\n"); + print("gBias : "); print(gBias.layout()); print("\n"); + print("sAcc : "); print(sAcc.layout()); print("\n"); + print("\n"); + print("tRS_sAcc : "); print(tRS_sAcc.layout()); print("\n"); + print("tRS_rAcc : "); print(tRS_rAcc.layout()); print("\n"); + print("\n"); + print("gDt : "); print(gDt.layout()); print("\n"); + print("tSR_sAcc : "); print(tSR_sAcc.layout()); print("\n"); + print("tSR_rAcc : "); print(tSR_rAcc.layout()); print("\n"); + print("\n"); + print("tSR_rC : "); print(tSR_rC.layout()); print("\n"); + print("tSR_rD : "); print(tSR_rD.layout()); print("\n"); + print("tSR_gC : "); print(tSR_gC.layout()); print("\n"); + print("tSR_gD : "); print(tSR_gD.layout()); print("\n"); + print("\n"); + print("gBiast : "); print(gBiast.layout()); print("\n"); + print("tSR_gBias : "); print(tSR_gBias.layout()); print("\n"); + print("tSR_rBias : "); print(tSR_rBias.layout()); print("\n"); + } +#endif + + if constexpr (IsEpilogueBiasSupported) { + if (params.ptr_Bias) { + // Filter so we don't issue redundant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + Tensor tSR_gBias_flt = filter_zeros(tSR_gBias); + Tensor tSR_rBias_flt = filter_zeros(tSR_rBias); + Tensor tSR_cD_flt = filter_zeros(tSR_cD, tSR_gBias.stride()); + + // Step 0. Copy Bias from GMEM to fragment + auto pred_fn = [&] (auto const&... coords) { return elem_less(tSR_cD_flt(coords...), take<0, 2>(residue_mnk)); }; + copy_if(pred_fn, tSR_gBias_flt, tSR_rBias_flt); + } + } + + // For each tiling needed for SmemLayout to cover shape(gD) + CUTLASS_PRAGMA_UNROLL + for (int step_m = 0; step_m < size<2>(cDt); ++step_m) { + CUTLASS_PRAGMA_UNROLL + for (int step_n = 0; step_n < size<3>(cDt); ++step_n) { + // Step 1. Copy to SMEM + CUTLASS_PRAGMA_UNROLL + for (int pipe_m = 0; pipe_m < size<1>(tRS_sAcc); ++pipe_m) { + CUTLASS_PRAGMA_UNROLL + for (int pipe_n = 0; pipe_n < size<2>(tRS_sAcc); ++pipe_n) { + int mma_m = step_m * size<1>(tRS_sAcc) + pipe_m; + int mma_n = step_n * size<2>(tRS_sAcc) + pipe_n; + + copy(tiled_r2s, tRS_rAcc(_,mma_m,mma_n), tRS_sAcc(_,pipe_m,pipe_n)); + } + } + + // Step 2. Wait for SMEM writes to complete + synchronize(); + + // Step 3. Copy from SMEM into a fragment + copy(tiled_s2r, tSR_sAcc, tSR_rAcc); + + // Step 4. 
Wait for SMEM reads to complete + synchronize(); + + Tensor tSR_gDmn = tSR_gD(_,_,_,step_m,step_n); + Tensor tSR_cDmn = tSR_cD(_,_,_,step_m,step_n); + + if constexpr (IsEpilogueBiasSupported) { + Tensor tSR_rBiasmn = tSR_rBias(_,_,_,step_m,step_n); + + if (epilogue_op.is_source_needed()) { + // source is needed + Tensor tSR_gCmn = tSR_gC(_,_,_,step_m,step_n); + + // Step 5. Copy C from GMEM to a fragment + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tSR_gDmn); ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tSR_gDmn); ++n) { + // Predication + if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(tSR_rAcc); ++i) { + tSR_rC(i,m,n) = tSR_gCmn(i,m,n); + } + } + } + } + + // Step 6. Elementwise operation with conversion + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tSR_rAcc); ++i) { + if constexpr (IsActHasArgs) { + epilogue_op(tSR_rD(i), tSR_rD(i), tSR_rAcc(i), tSR_rC(i), tSR_rBiasmn(i), params.activation); + } else { + epilogue_op(tSR_rD(i), tSR_rD(i), tSR_rAcc(i), tSR_rC(i), tSR_rBiasmn(i)); + } + } + } + else { + // source is not needed, avoid load and lift compute + + // Step 5. Elementwise operation with conversion + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tSR_rAcc); ++i) { + if constexpr (IsActHasArgs) { + epilogue_op(tSR_rD(i), tSR_rD(i), tSR_rAcc(i), tSR_rBiasmn(i), params.activation); + } else { + epilogue_op(tSR_rD(i), tSR_rD(i), tSR_rAcc(i), tSR_rBiasmn(i)); + } + } + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tSR_gDmn); ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tSR_gDmn); ++n) { + // Predication + if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { + // The Last Step. Copy to GMEM + copy(CopyAtomR2G{}, tSR_rD(_,m,n), tSR_gDmn(_,m,n)); + } + } + } + } else { + if (epilogue_op.is_source_needed()) { + // source is needed + Tensor tSR_gCmn = tSR_gC(_,_,_,step_m,step_n); + + // Step 5. Copy C from GMEM to a fragment + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tSR_gDmn); ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tSR_gDmn); ++n) { + // Predication + if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(tSR_rAcc); ++i) { + tSR_rC(i,m,n) = tSR_gCmn(i,m,n); + } + } + } + } + + // Step 6. Elementwise operation with conversion + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tSR_rAcc); ++i) { + tSR_rD(i) = epilogue_op(tSR_rAcc(i), tSR_rC(i)); + } + } + else { + // source is not needed, avoid load and lift compute + + // Step 5. Elementwise operation with conversion + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tSR_rAcc); ++i) { + tSR_rD(i) = epilogue_op(tSR_rAcc(i)); + } + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tSR_gDmn); ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tSR_gDmn); ++n) { + // Predication + if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { + // The Last Step. 
Copy to GMEM + copy(CopyAtomR2G{}, tSR_rD(_,m,n), tSR_gDmn(_,m,n)); + } + } + } + } + } + } + } + +private: + Params params; + ThreadEpilogueOp epilogue_op; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp b/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp new file mode 100644 index 0000000000..8a70370b21 --- /dev/null +++ b/include/cutlass/epilogue/collective/sm70_epilogue_vectorized_array.hpp @@ -0,0 +1,412 @@ +/*************************************************************************************************** + * Copyright (c) 2024 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/epilogue/collective/sm70_epilogue_vectorized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// +/// Ptr Array Epilogue Vectorized +/// Applies an element wise operation to all elements within the fragment +/// and writes it out to destination storage. 
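/// Unlike the single-batch specialization above, C and D are addressed through
/// per-batch pointer arrays (params.ptr_C[l] / params.ptr_D[l]); for grouped GEMM
/// the strides are likewise read per group inside operator().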
+/// +/// Ways to generalize this: +/// - CTA tile shape +/// - vectorization requirements (GMEM) +/// - vectoriz(able) transform() +/// +template < + class StrideC_, + class StrideD_, + class ThreadEpilogueOp_, + class SmemLayout_, + class CopyAtomR2S_, + class TiledCopyS2R_, + class CopyAtomR2G_, + class EpilogueScheduleType_ +> +class Epilogue< + StrideC_, + StrideD_, + ThreadEpilogueOp_, + SmemLayout_, + CopyAtomR2S_, + TiledCopyS2R_, + CopyAtomR2G_, + EpilogueScheduleType_, + cute::enable_if_t< + cute::is_same_v + > + > { +public: + // + // Type Aliases + // + // derived types of output thread level operator + using ThreadEpilogueOp = ThreadEpilogueOp_; + using ElementAccumulator = typename ThreadEpilogueOp::ElementAccumulator; + using ElementCompute = typename ThreadEpilogueOp::ElementCompute; + using ElementScalar = ElementCompute; + using ElementOutput = typename ThreadEpilogueOp::ElementOutput; + using ElementC = typename ThreadEpilogueOp::ElementC; + using StrideC = StrideC_; + using InternalStrideC = cute::remove_pointer_t; + using ElementD = typename ThreadEpilogueOp::ElementD; + using StrideD = StrideD_; + using InternalStrideD = cute::remove_pointer_t; + + using SmemLayout = SmemLayout_; + using CopyAtomR2S = CopyAtomR2S_; + using TiledCopyS2R = TiledCopyS2R_; + using CopyAtomR2G = CopyAtomR2G_; + + using GmemTiledCopyC = TiledCopyS2R; + using GmemTiledCopyD = TiledCopyS2R; + + static const int kOutputAlignment = ThreadEpilogueOp::kCount; + + using AlignmentType = typename cute::uint_bit::value * kOutputAlignment>::type; + + static_assert(cute::rank(InternalStrideC{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + static_assert(cute::rank(InternalStrideD{}) == 3, "StrideCD must be rank-3: [M, N, L]"); + + struct SharedStorage + { + cute::array_aligned> smem_epilogue; + }; + + using TensorMapStorage = SharedStorage; + + // Host side epilogue arguments + struct Arguments { + typename ThreadEpilogueOp::Params thread{}; + ElementC const** ptr_C = nullptr; + StrideC dC{}; + ElementD** ptr_D = nullptr; + StrideD dD{}; + }; + + // Device side epilogue params + using Params = Arguments; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const&, + Arguments const& args, + [[maybe_unused]] void* workspace) { + return args; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + template + static bool + can_implement( + [[maybe_unused]] ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + return true; + } + + CUTLASS_HOST_DEVICE + Epilogue(Params const& params_) + : params(params_) { } + + CUTLASS_DEVICE + bool + is_source_needed() { + // For Ptr-Array or Grouped Gemm we cannot determine if source is needed based on first beta. 
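    // (Conservatively assume C may be read; the per-batch beta, and therefore the
    //  real answer, is only known once operator() builds the thread-level op for
    //  the current l coordinate.)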
+ return true; + } + + template< + class ProblemShapeMNKL, + class BlockShapeMNK, + class BlockCoordMNKL, + class FrgEngine, class FrgLayout, + class TiledMma, + class ResidueMNK + > + CUTLASS_DEVICE void + operator()( + ProblemShapeMNKL problem_shape_mnkl, + BlockShapeMNK blk_shape_MNK, + BlockCoordMNKL blk_coord_mnkl, + cute::Tensor const& accumulators, // (MMA,MMA_M,MMA_N) + TiledMma tiled_mma, + ResidueMNK residue_mnk, + int thread_idx, + char* smem_buf) { + using namespace cute; + using X = Underscore; + + static_assert(cute::rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "ThreadBlock tile shape must be static"); + static_assert(cute::rank(BlockShapeMNK{}) == 3, "BlockShapeMNK must be rank 3"); + static_assert(cute::rank(BlockCoordMNKL{}) == 4, "BlockCoordMNKL must be rank 3"); + + // synchronizing function for smem reads/writes +#if CUDA_BARRIER_ENABLED + auto synchronize = [] () { cutlass::arch::NamedBarrier::sync(typename TiledCopyS2R::TiledNumThr{}, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; +#else + auto synchronize = [] () { __syncthreads(); }; +#endif + + // Separate out problem shape for convenience + auto M = get<0>(problem_shape_mnkl); + auto N = get<1>(problem_shape_mnkl); + auto L = get<3>(problem_shape_mnkl); + // Batches are managed by using appropriate pointers to C and D matrices + const int32_t mock_L = 1; + const int32_t mock_l_coord = 0; + // Slice to get the tile this CTA is responsible for + auto [m_coord, n_coord, k_coord, l_coord] = blk_coord_mnkl; + + // If scalar alpha/beta are provided, i.e., same alpha/beta applies to all batches/groups. + // If pointers to alpha/beta are provided, i.e., alpha/beta can differ between batches/groups, + // we get the correct alpha/beta values for the current batch/group using group index. 
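+    // Illustration (hypothetical host-side setup; member names follow the LinearCombination-style
+    // operators checked elsewhere in this change and may differ for other thread operators):
+    //   args.thread.beta = 0.f;                // one beta applied to every batch/group, or
+    //   args.thread.beta_ptr_array = d_betas;  // one beta per group, selected below via l_coord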
+ ThreadEpilogueOp epilogue_op = ThreadEpilogueOp(params.thread, l_coord); + + if (epilogue_op.is_source_needed() && params.dC == nullptr) { + // Beta value is non-zero while pointer to C is a nullptr + assert(0); + } + + InternalStrideC stride_c; + InternalStrideD stride_d; + if constexpr (!cute::is_same_v) { + // If grouped gemm + if (epilogue_op.is_source_needed()) { + stride_c = params.dC[l_coord]; + } + stride_d = params.dD[l_coord]; + } + else { + stride_c = params.dC; + stride_d = params.dD; + } + + // Represent the full output tensor + ElementC const* ptr_C_l = nullptr; + if (epilogue_op.is_source_needed()) { + ptr_C_l = params.ptr_C[l_coord]; + } + Tensor mC_mnl = make_tensor(make_gmem_ptr(ptr_C_l), make_shape(M,N,mock_L), stride_c); // (m,n,l) + Tensor mD_mnl = make_tensor(make_gmem_ptr(params.ptr_D[l_coord]), make_shape(M,N,mock_L), stride_d); // (m,n,l) + Tensor gC_mnl = local_tile(mC_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + Tensor gD_mnl = local_tile(mD_mnl, blk_shape_MNK, make_coord(_,_,_), Step<_1,_1, X>{}); // (BLK_M,BLK_N,m,n,l) + + Tensor gC = gC_mnl(_,_,m_coord,n_coord,mock_l_coord); // (BLK_M,BLK_N) + Tensor gD = gD_mnl(_,_,m_coord,n_coord,mock_l_coord); // (BLK_M,BLK_N) + + // Construct a tensor in SMEM that we can partition for rearranging data + SharedStorage& storage = *reinterpret_cast(smem_buf); + Tensor sAcc = make_tensor(make_smem_ptr(storage.smem_epilogue.data()), SmemLayout{}); // (SMEM_M,SMEM_N) + + // Partition sAcc to match the accumulator partitioning + auto tiled_r2s = make_tiled_copy_C(CopyAtomR2S{}, tiled_mma); + auto thread_r2s = tiled_r2s.get_thread_slice(thread_idx); + Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((Atom,AtomNum), MMA_M, MMA_N) + Tensor tRS_sAcc = thread_r2s.partition_D(sAcc); // ((Atom,AtomNum),PIPE_M,PIPE_N) + + // Tile gD and gC by the shape of SmemLayout first + auto tile = make_shape(size<0>(sAcc), size<1>(sAcc)); + Tensor gCt = flat_divide(gC, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + Tensor gDt = flat_divide(gD, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + + // Partition sAcc, gC, and gD for the output + auto tiled_s2r = TiledCopyS2R{}; + auto thread_s2r = tiled_s2r.get_thread_slice(thread_idx); + Tensor tSR_sAcc = thread_s2r.partition_S(sAcc); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tSR_gC = thread_s2r.partition_D(gCt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + Tensor tSR_gD = thread_s2r.partition_D(gDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + + // Allocate intermediate registers on the dst tensors + Tensor tSR_rAcc = make_tensor(take<0,3>(shape(tSR_gC))); // ((Atom,AtomNum),ATOM_M,ATOM_N) + Tensor tSR_rD = make_tensor(shape(tSR_rAcc)); // ((Atom,AtomNum),ATOM_M,ATOM_N) + + // Repeat the D-partitioning for coordinates and predication + Tensor cD = make_identity_tensor(make_shape(size<0>(gD),size<1>(gD))); // (BLK_M,BLK_N) -> (blk_m,blk_n) + Tensor cDt = flat_divide(cD, tile); // (SMEM_M,SMEM_N,TILE_M,TILE_N) + Tensor tSR_cD = thread_s2r.partition_D(cDt); // ((Atom,AtomNum),ATOM_M,ATOM_N,TILE_M,TILE_N) + + CUTE_STATIC_ASSERT(size<1>(tRS_rAcc) % size<3>(tSR_gC) == 0); // TILE_M divides MMA_M + CUTE_STATIC_ASSERT(size<2>(tRS_rAcc) % size<4>(tSR_gC) == 0); // TILE_N divides MMA_N + +#if 0 + if (thread_idx == 0 && m_coord == 0 && n_coord == 0) { + print("aC : "); print(accumulators.layout()); print("\n"); + print("gC : "); print(gC.layout()); print("\n"); + print("gD : "); print(gD.layout()); print("\n"); + print("sAcc : "); print(sAcc.layout()); print("\n"); + 
print("\n"); + print("tRS_sAcc : "); print(tRS_sAcc.layout()); print("\n"); + print("tRS_rAcc : "); print(tRS_rAcc.layout()); print("\n"); + print("\n"); + print("gDt : "); print(gDt.layout()); print("\n"); + print("tSR_sAcc : "); print(tSR_sAcc.layout()); print("\n"); + print("tSR_rAcc : "); print(tSR_rAcc.layout()); print("\n"); + print("\n"); + print("tSR_rD : "); print(tSR_rD.layout()); print("\n"); + print("tSR_gC : "); print(tSR_gC.layout()); print("\n"); + print("tSR_gD : "); print(tSR_gD.layout()); print("\n"); + print("\n"); + } +#endif + + // For each tiling needed for SmemLayout to cover shape(gD) + CUTLASS_PRAGMA_UNROLL + for (int step_m = 0; step_m < size<2>(cDt); ++step_m) { + CUTLASS_PRAGMA_UNROLL + for (int step_n = 0; step_n < size<3>(cDt); ++step_n) { + // Step 1. Copy to SMEM + CUTLASS_PRAGMA_UNROLL + for (int pipe_m = 0; pipe_m < size<1>(tRS_sAcc); ++pipe_m) { + CUTLASS_PRAGMA_UNROLL + for (int pipe_n = 0; pipe_n < size<2>(tRS_sAcc); ++pipe_n) { + int mma_m = step_m * size<1>(tRS_sAcc) + pipe_m; + int mma_n = step_n * size<2>(tRS_sAcc) + pipe_n; + + copy(tiled_r2s, tRS_rAcc(_,mma_m,mma_n), tRS_sAcc(_,pipe_m,pipe_n)); + } + } + + // Step 2. Wait for SMEM writes to complete + synchronize(); + + // Step 3. Copy from SMEM into a fragment + copy(tiled_s2r, tSR_sAcc, tSR_rAcc); + + // Step 4. Wait for SMEM reads to complete + synchronize(); + + Tensor tSR_gDmn = tSR_gD(_,_,_,step_m,step_n); + Tensor tSR_cDmn = tSR_cD(_,_,_,step_m,step_n); + + if (epilogue_op.is_source_needed()) { + // source is needed + Tensor tSR_gCmn = tSR_gC(_,_,_,step_m,step_n); + + Tensor tSR_rCmn = make_tensor(shape(tSR_gCmn)); // ((Atom,AtomNum),ATOM_M,ATOM_N) + + // Step 5. Copy C from GMEM to a fragment + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tSR_gDmn); ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tSR_gDmn); ++n) { + // Predication + if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(tSR_rAcc); ++i) { + tSR_rCmn(i,m,n) = tSR_gCmn(i,m,n); + } + } + } + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tSR_gDmn); ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tSR_gDmn); ++n) { + // Predication + if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { + // Step 6. Elementwise operation with conversion + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size<0>(tSR_rAcc); ++i) { + tSR_rD(i,m,n) = epilogue_op(tSR_rAcc(i,m,n), tSR_rCmn(i,m,n)); + } + // Step 7. Copy to GMEM + copy(CopyAtomR2G{}, tSR_rD(_,m,n), tSR_gDmn(_,m,n)); + } + } + } + } + else { + // source is not needed, avoid load and lift compute + + // Step 5. Elementwise operation with conversion + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tSR_rAcc); ++i) { + tSR_rD(i) = epilogue_op(tSR_rAcc(i)); + } + + CUTLASS_PRAGMA_UNROLL + for (int m = 0; m < size<1>(tSR_gDmn); ++m) { + CUTLASS_PRAGMA_UNROLL + for (int n = 0; n < size<2>(tSR_gDmn); ++n) { + // Predication + if (elem_less(tSR_cDmn(0,m,n), take<0,2>(residue_mnk))) { + // Step 6. 
Copy to GMEM + copy(CopyAtomR2G{}, tSR_rD(_,m,n), tSR_gDmn(_,m,n)); + } + } + } + } + } + } + } + +private: + Params params; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp new file mode 100644 index 0000000000..54fe9b1daf --- /dev/null +++ b/include/cutlass/epilogue/collective/sm90_epilogue_array_tma_warpspecialized.hpp @@ -0,0 +1,1198 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. 
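+
+    This header implements the pointer-array / grouped GEMM variant of the SM90 TMA
+    warp-specialized collective epilogue (dispatch policy Sm90PtrArrayTmaWarpSpecialized).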
+*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/layout.hpp" +#include "cutlass/trace.h" +#include "cutlass/cuda_host_adapter.hpp" + +#include "cute/tensor.hpp" +#include "cute/atom/copy_traits_sm90_tma.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_, + int NumEpilogueWarpGroups_, + class CtaTileMNK_, // (CTA_M,CTA_N,CTA_K) + class EpilogueTile_, // (EPI_TILE_M,EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class FusionCallbacks_, + class CopyOpG2S_, + class SmemLayoutAtomC_, + class CopyOpS2R_, + class CopyOpS2G_, + class SmemLayoutAtomD_, + class CopyOpR2S_, + class CopyAtomC_, + class CopyOpR2R_ +> +class CollectiveEpilogue< + Sm90PtrArrayTmaWarpSpecialized, + CtaTileMNK_, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpG2S_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpS2G_, + SmemLayoutAtomD_, + CopyOpR2S_, + CopyAtomC_, + CopyOpR2R_ +> { +public: + // + // Type Aliases + // + using DispatchPolicy = Sm90PtrArrayTmaWarpSpecialized; + using CtaTileMNK = CtaTileMNK_; + using EpilogueTile = EpilogueTile_; + using FusionCallbacks = FusionCallbacks_; + using ElementC = ElementC_; + using StrideC = StrideC_; + using InternalStrideC = cute::remove_pointer_t; + using ElementD = ElementD_; + using StrideD = StrideD_; + using InternalStrideD = cute::remove_pointer_t; + using CopyOpG2S = CopyOpG2S_; + using SmemLayoutAtomC = SmemLayoutAtomC_; + using CopyOpS2R = CopyOpS2R_; + using CopyOpS2G = CopyOpS2G_; + using SmemLayoutAtomD = SmemLayoutAtomD_; + using CopyOpR2S = CopyOpR2S_; + using CopyAtomC = CopyAtomC_; + using CopyOpR2R = CopyOpR2R_; + + using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits::Operation; + using GmemTiledCopyC = CopyOpG2S; + using GmemTiledCopyD = CopyOpS2G; + + static_assert(!is_layout::value && is_tuple::value, "EpilogueTile must be a cute::Tile or cute::Shape"); + static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]"); + static_assert(size<0>(CtaTileMNK{}) % size<0>(shape(EpilogueTile{})) == 0, "EPI_TILE_M must divide CTA_M"); + static_assert(size<1>(CtaTileMNK{}) % size<1>(shape(EpilogueTile{})) == 0, "EPI_TILE_N must divide CTA_N"); + static_assert(cute::rank(InternalStrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); + static_assert(cute::rank(InternalStrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); + +private: + constexpr static bool is_source_supported = not cute::is_void_v; + constexpr static bool is_destination_supported = not cute::is_void_v; + using NonVoidElementD = cute::conditional_t, ElementD>; + static_assert(not cute::is_void_v, "SmemElementD is void"); + using NonVoidElementC = 
cute::conditional_t; // prevents void ref breakages + + using SmemElementC = typename cutlass::detail::get_unpacked_element_type::type; + using SmemElementD = typename cutlass::detail::get_unpacked_element_type::type; + + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static bool ReuseSmemC = ReuseSmemC_ and is_destination_supported; + constexpr static bool DelayTmaStore = DelayTmaStore_; + + constexpr static bool is_m_major_C = detail::is_m_major(); + constexpr static bool is_m_major_D = detail::is_m_major(); + + constexpr static bool is_im2col_C = cute::is_same_v; + constexpr static bool is_im2col_D = cute::is_same_v; + + // Check if register transformation is needed before copying register to shared memory. + constexpr static bool IsUseR2R = !cute::is_void_v; + + using SmemLayoutC = decltype(tile_to_shape( + SmemLayoutAtomC{}, + make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{} )); + using SmemLayoutD = decltype(tile_to_shape( + SmemLayoutAtomD{}, + make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{} )); + + constexpr static bool support_smem_reuse = is_source_supported && is_destination_supported && StagesD <= StagesC + && cosize(take<0,2>(SmemLayoutC{})) == cosize(take<0,2>(SmemLayoutD{})); + static_assert(not (ReuseSmemC && not support_smem_reuse), "Smem reuse requirements not met"); + + constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); + constexpr static size_t SmemAlignmentC = cutlass::detail::alignment_for_swizzle(SmemLayoutC{}); + constexpr static size_t MaxSmemAlignment = cute::max(SmemAlignmentC, SmemAlignmentD); + + using SmemArrayTypeC = cute::ArrayEngine>; + using SmemArrayTypeD = cute::ArrayEngine>; + + using EmptyType = cute::tuple<>; + using SmemCStorage = cute::conditional_t; + using SmemDStorage = cute::conditional_t; + + struct CollectiveStorageWithC { + alignas(SmemAlignmentC) ArrayEngine> smem_C; + alignas(SmemAlignmentD) ArrayEngine> smem_D; + }; + + union CollectiveStorageWithoutC { + cute::array smem_C; + alignas(SmemAlignmentD) ArrayEngine> smem_D; + }; + + union CollectiveStorageReuseC { + alignas(MaxSmemAlignment) ArrayEngine> smem_C; + alignas(MaxSmemAlignment) ArrayEngine> smem_D; + }; + +public: + // TMA pipeline for loading C + using LoadPipeline = cutlass::PipelineTransactionAsync; + using LoadPipelineState = cutlass::PipelineState; + constexpr static uint32_t TmaTransactionBytes = + (size(take<0,2>(SmemLayoutC{})) * static_cast(sizeof_bits::value)) / 8; + constexpr static bool RequiresTransactionBytes = true; + + constexpr static int NumEpilogueWarpGroups = NumEpilogueWarpGroups_; + + // TMA pipeline for storing D + using StorePipeline = cute::conditional_t, + cutlass::PipelineTmaStore>; + using StorePipelineState = cutlass::PipelineState; + + struct SharedStorage { + struct TensorStorage { + using CollectiveStorage = cute::conditional_t>; + CollectiveStorage collective; + + using FusionStorage = typename FusionCallbacks::SharedStorage; + FusionStorage thread; + } tensors; + + struct TensorMapStorage : cute::aligned_struct<128, _0> { + cute::TmaDescriptor smem_tensormap_C; + cute::array smem_tensormap_D; + } tensormaps; + + using PipelineStorage = typename LoadPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using TensorMapStorage = typename SharedStorage::TensorMapStorage; 
+ using PipelineStorage = typename SharedStorage::PipelineStorage; + + static constexpr bool IsGroupedGemmKernel = !cute::is_same_v; + + // Host side epilogue arguments + struct Arguments { + typename FusionCallbacks::Arguments thread{}; + ElementC const** ptr_C = nullptr; + StrideC dC; + ElementD ** ptr_D = nullptr; + StrideD dD; + }; + + // Device side epilogue params + struct Params { + using TMA_C = decltype(make_tma_copy( + CopyOpG2S{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), + repeat_like(InternalStrideC{}, int32_t(0)), InternalStrideC{}), + take<0,2>(SmemLayoutC{}), + EpilogueTile{}, + _1{})); + + using TMA_D = decltype(make_tma_copy( + CopyOpS2G{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), + repeat_like(InternalStrideD{}, int32_t(0)), InternalStrideD{}), + take<0,2>(SmemLayoutD{}), + EpilogueTile{}, + _1{})); + + typename FusionCallbacks::Params thread{}; + TMA_C tma_load_c; + TMA_D tma_store_d; + cute::TmaDescriptor* tensormaps; + ElementC const** ptr_C; + StrideC dC; + ElementD** ptr_D; + StrideD dD; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + // These tensor shapes (only applicable for grouped gemm) and pointers are only used to create tensormap/tma desc. + // These will be replaced with correct values before the initial tma load. + auto init_shape = repeat_like(append<4>(typename ProblemShape::UnderlyingProblemShape{}, 1), int32_t(1)); + auto init_M = get<0>(init_shape); + auto init_N = get<1>(init_shape); + auto init_L = get<3>(init_shape); + + static_assert(!is_im2col_C and !is_im2col_D, "Im2Col not supported on C or D"); + + InternalStrideC stride_c; + InternalStrideD stride_d; + if constexpr (IsGroupedGemmKernel) { + // Strides for Grouped Gemm will be replaced prior to the first access regardless. + stride_c = InternalStrideC{}; + stride_d = InternalStrideD{}; + } + else { + // Tensor shapes for Ptr-Array are initialized correctly only here. 
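+      // For Ptr-Array, every batch shares one problem shape, so batch 0's host shape is
+      // representative here; for grouped GEMM the shapes/strides are patched per group on
+      // device via tensormaps_replace_global_tensor_properties() before the first access.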
+ auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(0), 1); + init_M = get<0>(problem_shape_MNKL); + init_N = get<1>(problem_shape_MNKL); + init_L = get<3>(problem_shape_MNKL); + + stride_c = args.dC; + stride_d = args.dD; + } + + uint32_t transaction_bytes = TmaTransactionBytes; + typename Params::TMA_C tma_load_c{}; + if constexpr (is_source_supported) { + ElementC const* ptr_C_first_batch = reinterpret_cast(args.ptr_C); + Tensor tensor_c = make_tensor(ptr_C_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_c, _0{}))); + tma_load_c = make_tma_copy( + CopyOpG2S{}, + tensor_c, + take<0,2>(SmemLayoutC{}), + EpilogueTile{}, + _1{}); + } + + typename Params::TMA_D tma_store_d; + if constexpr (is_destination_supported) { + ElementD const* ptr_D_first_batch = reinterpret_cast(args.ptr_D); + Tensor tensor_d = make_tensor(ptr_D_first_batch, make_layout(make_shape(init_M,init_N,init_L), append<3>(stride_d, _0{}))); + tma_store_d = make_tma_copy( + CopyOpS2G{}, + tensor_d, + take<0,2>(SmemLayoutD{}), + EpilogueTile{}, + _1{}); + } + + auto fusion_workspace = static_cast(workspace); + auto fusion_workspace_size = FusionCallbacks::get_workspace_size(problem_shape, args.thread); + auto tma_descriptor_workspace = reinterpret_cast( + static_cast(workspace) + fusion_workspace_size); + + return { + FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, fusion_workspace), + tma_load_c, + tma_store_d, + tma_descriptor_workspace, + args.ptr_C, + args.dC, + args.ptr_D, + args.dD, + transaction_bytes, + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args, int sm_count) { + + constexpr uint32_t NumInputTensors = NumEpilogueWarpGroups + (cute::is_void_v ? 
0 : 1); + auto descriptors_shape = cute::make_shape(sm_count, Int{}); + constexpr size_t SizeOfCuTensorMap = sizeof(cute::TmaDescriptor); + + // Allocate gmem space for input tensormaps per each SM, A tensormap copies followed by B tensormap copies + return (size(descriptors_shape) * SizeOfCuTensorMap) + FusionCallbacks::get_workspace_size(problem_shape, args.thread); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return FusionCallbacks::initialize_workspace(problem_shape, args.thread, workspace, stream, cuda_adapter); + } + + template + static bool + can_implement( + ProblemShape problem_shape, + [[maybe_unused]] Arguments const& args) { + + bool implementable = true; + bool fusion_implementable = true; + + if (problem_shape.is_host_problem_shape_available()) { + for (int i = 0; i < problem_shape.groups(); ++i) { + auto problem_shape_MNKL = append<4>(problem_shape.get_host_problem_shape(i), 1); + auto [M,N,K,L] = problem_shape_MNKL; + + if constexpr (is_destination_supported) { + constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits(); + constexpr int min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideD{}); + } + + if constexpr (is_source_supported) { + constexpr int tma_alignment_bits_C = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_C = tma_alignment_bits_C / cutlass::sizeof_bits::value; + implementable = implementable && cutlass::detail::check_alignment(cute::make_shape(M,N,L), InternalStrideC{}); + } + + fusion_implementable = fusion_implementable && FusionCallbacks::can_implement(problem_shape_MNKL, args.thread); + } + } + else { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Ignoring check to can implement because host problem shape is not available.\n"); + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + if (!fusion_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); + } + + bool beta_implementable = true; + + if (cute::is_void_v || args.ptr_C == nullptr) { + if constexpr (detail::has_beta::value) { + beta_implementable = args.thread.beta == 0.0; + } + if constexpr (detail::has_beta_ptr::value) { + beta_implementable = beta_implementable && args.thread.beta_ptr == nullptr; + } + if constexpr (detail::has_beta_ptr_array::value) { + beta_implementable = beta_implementable && args.thread.beta_ptr_array == nullptr; + } + } + + if (!beta_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Beta/beta pointer was set, but epilogue is sourceless (void-C).\n"); + } + + return implementable && fusion_implementable && beta_implementable; + } + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_load_pipe_increment(TileShapeMNK tile_shape_MNK) { + // Compute number of epilogue subtiles + return size<1>(zipped_divide(make_layout(take<0,2>(tile_shape_MNK)), EpilogueTile{})); + } + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_store_pipe_increment(TileShapeMNK tile_shape_MNK) { + return get_load_pipe_increment(tile_shape_MNK); + } + + CUTLASS_HOST_DEVICE + CollectiveEpilogue(Params const& params_, TensorStorage& shared_tensors) + : 
params(params_), fusion_callbacks(params_.thread, shared_tensors.thread) {} + + CUTLASS_DEVICE + bool + is_producer_load_needed() const { + return fusion_callbacks.is_producer_load_needed(); + } + + CUTLASS_DEVICE auto + load_init( + Params const& params, + TensorMapStorage& shared_tensormaps, + int32_t sm_count, + int32_t sm_idx) { + // Initialize tma for loading + constexpr bool IsLoad = true; + auto load_tensormaps = tensormaps_init(params, shared_tensormaps, sm_count, sm_idx, 0); + return load_tensormaps; + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class TiledMma, + class TensorMapC, + __CUTE_REQUIRES(std::is_pointer_v) + > + CUTLASS_DEVICE auto + load( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + TiledMma tiled_mma, + int thread_idx, + TensorStorage& shared_tensors, + TensorMapC const& load_tensormap, + int subtile_idx=-1, + bool wait_until_load_finishes = false) { + using namespace cute; + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + + static_assert(!is_im2col_D, "Do not support im2col"); + + auto coord_shape = append<3>(make_shape(m_coord, n_coord), Int<0>{}); + + // Represent the full source tensor, slice to get the tile this CTA is currently responsible for + Tensor mC_mn = params.tma_load_c.get_tma_tensor(append<3>(make_shape(M,N), Int<1>{})); // (M,N,L) + Tensor mC = coalesce(mC_mn, take<0,2>(CtaTileMNK{})); + Tensor gC = local_tile(mC, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N) + + // Apply epilogue subtile, get matching smem tensor + auto ptr_sC = shared_tensors.collective.smem_C.begin(); + Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor sC_epi = make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + + // Prepare the thread(b)lock's (G)mem to (S)mem TMA tiled copy (bGS_) + ThrCopy thrblk_g2s = params.tma_load_c.get_slice(Int<0>{}); + Tensor bGS_gC = thrblk_g2s.partition_S(gC_epi); // (G2S,G2S_M,G2S_N,EPI_M,EPI_N) + Tensor bGS_sC = thrblk_g2s.partition_D(sC_epi); // (G2S,G2S_M,G2S_N,PIPE_C) + + // Get the fusion callbacks for the producer load warp + auto pld_args = cutlass::epilogue::fusion::detail::ProducerLoadArgs{ + problem_shape_mnkl, + CtaTileMNK{}, + tile_coord_mnkl, + tiled_mma, + EpilogueTile{}, + thread_idx + }; + auto pld_callbacks = fusion_callbacks.get_producer_load_callbacks(pld_args); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + LoadPipelineState last_load_producer_state = load_pipe_producer_state; + + // Predication for TMA load (one thread issues TMA load) + bool issue_tma_load = cute::elect_one_sync(); + + // Pre-loop fusion callback entry point + pld_callbacks.begin(); + + LoadPipelineState prior_state = load_pipe_producer_state; + + bool did_load = false; + + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < size<3>(gC_epi); ++epi_n) { + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < size<2>(gC_epi); ++epi_m) { + if (subtile_idx != -1 && (epi_n * static_cast(size<2>(gC_epi)) + epi_m) != subtile_idx) { + continue; + } + + // Acquire the lock for this stage + constexpr uint16_t mcast_mask = 0; + uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); + + load_pipeline.producer_acquire(load_pipe_producer_state); + + // Loop 
fusion callback entry point + pld_callbacks.step(tma_barrier, epi_m, epi_n, load_pipe_producer_state.count(), issue_tma_load); + + // Execute the TMA load for C if needed + if (is_C_load_needed) { + if (issue_tma_load) { + copy(params.tma_load_c.with(load_tensormap, *tma_barrier, mcast_mask), + bGS_gC(_,_,_,epi_m,epi_n), bGS_sC(_,_,_,load_pipe_producer_state.index())); + load_pipeline.producer_expect_transaction(load_pipe_producer_state); + } + last_load_producer_state = load_pipe_producer_state; + did_load = true; + } + + // Commit TMA loads for this stage and release the lock + load_pipeline.producer_commit(load_pipe_producer_state); + ++load_pipe_producer_state; + } + } + + // Post-loop fusion callback entry point + pld_callbacks.end(); + + if (wait_until_load_finishes && did_load) { + typename CollectiveEpilogue::LoadPipelineState epi_load_pipe_tma_consumer_state = + {last_load_producer_state.index(), !last_load_producer_state.phase(), last_load_producer_state.count()}; + load_pipeline.consumer_wait(epi_load_pipe_tma_consumer_state); + } + + return load_pipe_producer_state; + } + + CUTLASS_DEVICE auto + load_tail( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state) { + + if (!fusion_callbacks.is_producer_load_needed()) { + return load_pipe_producer_state; + } + + bool issue_tma_load = cute::elect_one_sync(); + if (issue_tma_load) { + load_pipeline.producer_tail(load_pipe_producer_state); + } + + return load_pipe_producer_state; + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class AccEngine, class AccLayout, + class TiledMma, + class TensorMapD + > + CUTLASS_DEVICE auto + store( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + cute::Tensor accumulators, + TiledMma tiled_mma, + int thread_idx, + TensorStorage& shared_tensors, + TensorMapD const& store_tensormap, + int subtile_idx=-1) { + + using namespace cute; + using ElementAccumulator = typename AccEngine::value_type; + using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; + using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; + + static_assert(is_rmem::value, "Accumulator must be RF resident."); + static_assert(rank(AccLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA,MMA_M,MMA_N)"); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "TileShapeMNK must be static"); + static_assert(rank(TileShapeMNK{}) == 3, "TileShapeMNK must be rank 3"); + static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + + + static_assert(!is_im2col_D, "Do not support im2col"); + + auto coord_shape = append<3>(make_shape(m_coord, n_coord), Int<0>{}); + + // Represent the full output tensor, slice to get the tile this CTA is responsible for + Tensor mD_mn = params.tma_store_d.get_tma_tensor(append<3>(make_shape(M,N), Int<1>{})); // (M,N,L) + + Tensor mD = coalesce(mD_mn, take<0,2>(CtaTileMNK{})); + Tensor gD = local_tile(mD, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N) + + // Apply epilogue subtiling + Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + 
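+    // The CTA tile is traversed as EPI_M x EPI_N subtiles of shape EpilogueTile; each
+    // subtile is staged through the pipelined smem buffers constructed below.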
+ // Construct the corresponding pipelined smem tensors + auto ptr_sC = shared_tensors.collective.smem_C.begin(); + auto ptr_sD = shared_tensors.collective.smem_D.begin(); + Tensor sC_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + Tensor sD_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sD), SmemLayoutD{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_D) + + TiledCopy tiled_copy_C_atom = make_tiled_copy_C_atom(CopyAtomC{}, tiled_mma); + + // (t)hread-partition for (r)egister to (r)egister copy (tRR_) + TiledCopy tiled_r2r = [&]() { + if constexpr (IsUseR2R) { + return make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + } + else { + return make_tiled_copy_S(Copy_Atom, + ElementCompute>{}, tiled_copy_C_atom); + } + }(); + ThrCopy thread_r2r = tiled_r2r.get_slice(thread_idx); + + // (t)hread-partition for (r)egister to (s)mem copy (tRS_) + TiledCopy tiled_r2s = [&]() { + if constexpr (IsUseR2R) { + return make_tiled_copy_D(Copy_Atom{}, tiled_r2r); + } + else { + return make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + } + }(); + ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); + Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N) + Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D) + + auto mma_tile_m = size<0>(TileShapeMNK{}) / size<1>(tRS_rAcc); + auto mma_tile_n = size<1>(TileShapeMNK{}) / size<2>(tRS_rAcc); + auto epi_tile_m = size<0>(EpilogueTile{}); + auto epi_tile_n = size<1>(EpilogueTile{}); + + // Allocate D registers + Layout tRS_rD_layout = make_layout(take<0,3>(shape(thread_r2s.partition_S(sD_epi)))); + Tensor tRS_rD = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) + + // Vectorized fragment view + constexpr int FragmentSize = DispatchPolicy::FragmentSize; + Tensor tRS_rAcc_frg = recast>(tRS_rAcc); + Tensor tRS_rD_frg = recast>(tRS_rD); + CUTE_STATIC_ASSERT(size<0>(tRS_rAcc) % FragmentSize == 0, "Fragment size does not vectorize properly"); + + // (t)hread-partition for (s)mem to (r)egister copy (tSR_) + TiledCopy tiled_s2r = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx); + Tensor tSR_sC = thread_s2r.partition_S(sC_epi); // (S2R,S2R_M,S2R_N,PIPE_C) + Layout tSR_rC_layout = thread_s2r.retile_D(tRS_rD).layout(); // (S2R,S2R_M,S2R_N) + + // Allocate C registers + // If C smem load is a non-vectorized dst(i) = src(i) then we can allocate C registers directly in the compute type + // to eliminate some redundant pack+unpack instruction sequences for sub-word types + constexpr bool IsDirectS2R = cute::is_same_v> + && decltype(max_common_vector(tSR_rC_layout, tSR_sC.layout()))::value <= 1; + using RegisterElementC = cute::conditional_t; + Tensor tRS_rC = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) + Tensor tSR_rC = thread_s2r.retile_D(tRS_rC); // (S2R,S2R_M,S2R_N) + + // thread(b)lock-partition for (s)mem to (g)mem copy (bSG_) + ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{}); + Tensor bSG_sD = thrblk_s2g.partition_S(sD_epi); // (S2G,S2G_M,S2G_N,PIPE_D) + Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) + + // OOB predication for tile quantization "residue" + // Absolute coordinate tensors (dynamic) + Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) + Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) + Tensor tRS_cD_mn = 
thread_r2s.partition_S(flat_divide(cD_mn, EpilogueTile{})); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) + // Relative coordinate tensors (static) + Tensor cD = make_counting_tensor(cD_mn.layout()); // (CTA_M,CTA_N) + Tensor tRS_cD = make_counting_tensor(tRS_cD_mn.layout()); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) + // Subtract the global "bottom right" corner from the local "top left" corner to get the max relative coordinate + auto residue_cD = make_coord(M,N) - cD_mn(_0{}); // (m,n) + auto residue_tRS_cD = make_coord(M,N) - tRS_cD_mn(_0{}); // (m,n) + + CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M"); + + CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); + // Get TiledCopy for partition reference when consumer store. + TiledCopy tiled_copy_partition_ref = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + // Get the fusion callbacks for the consumer store warps + constexpr bool RefSrc = true; // Register tensors reference R2S copy src layout + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs{ + problem_shape_mnkl, + CtaTileMNK{}, + tile_coord_mnkl, + tiled_mma, + EpilogueTile{}, + tiled_copy_partition_ref, + cD, + residue_cD, + tRS_cD, + residue_tRS_cD, + tRS_rC, + thread_idx + }; + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // Thread synchronizer for previously issued waits or fences + // to ensure visibility of smem reads/writes to threads or TMA unit + auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + + // Predication for TMA store (one warp issues TMA store) + bool issue_tma_store = (thread_idx / NumThreadsPerWarp) == 0; + + // In the reuse smem configuration we have StagesC smem buffers and at most StagesD committed TMA stores in flight. + // The TMA store pipeline producer acquire returns when at most StagesD-1 committed stores are in-flight, so we can + // only guarantee store completion after StagesD iterations, then we can begin issuing releases on the smem buffer locks. + // store_pipe_producer_state tracks the acquire and load_pipe_consumer_state tracks the release, in circular buffer fashion. 
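+    // Concretely (assuming ReuseSmemC with StagesD == 2): the shared buffer written by the
+    // first store can only be released after two stores have been committed, since only then
+    // does producer_acquire guarantee the first store has drained; this is why, under smem
+    // reuse, load_wait_state below is seeded from the store-side state.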
+ LoadPipelineState load_wait_state = load_pipe_consumer_state; + if constexpr (ReuseSmemC) { + load_wait_state = store_pipe_producer_state; + load_wait_state.phase_ ^= 1; + } + + // We can delay issue of TMA store by one iteration to achieve better interleaving of non-TMA instructions + // Sync requirements of smem reuse may preclude this optimization + // Delayed stores cause delayed stage releases which causes deadlock when StagesC == StagesD + int epi_m_prev = 0, epi_n_prev = 0; + static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); + + // The TMA store sequence for one subtile iteration + auto tma_store_fn = [&] (int epi_m, int epi_n) { + // Write the tile from smem to gmem with TMA + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + synchronize(); // ensure all threads have issued their async fence + if constexpr (is_destination_supported) { + if (issue_tma_store) { + copy(params.tma_store_d.with(store_tensormap), bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); + } + } + + // Post async fence, pre TMA commit callback entry point + cst_callbacks.tma_store(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store); + + // Commit the TMA stores for this stage + if (issue_tma_store) { + store_pipeline.producer_commit(store_pipe_producer_state); + } + ++store_pipe_producer_state; + ++issued_stores; + + // Wait for the next smem buffer to be available + if (issue_tma_store) { + store_pipeline.producer_acquire(store_pipe_producer_state); + } + synchronize(); + + if constexpr (ReuseSmemC) { + // producer_acquire returns when at most StagesD-1 committed stores are pending + bool store_finished = issued_stores > StorePipeline::UnacquiredStages; + // Let dma warp know earliest smem buffer is consumed and empty after StagesD producer commits + if (store_finished) { + if (is_producer_load_needed) { + load_pipeline.consumer_release(load_pipe_consumer_state); + } + ++load_pipe_consumer_state; + } + } + }; + + // + // BEGIN EPILOGUE + // + + // Pre-loop fusion callback entry point + cst_callbacks.begin(); + + // For each output tile + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n) { + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m) { + bool is_first_iteration = epi_m == 0 && epi_n == 0; + bool is_last_iteration = epi_m == size<2>(gD_epi)-1 && epi_n == size<3>(gD_epi)-1; + + if (subtile_idx != -1 && (epi_n * static_cast(size<2>(gD_epi)) + epi_m) != subtile_idx) { + continue; + } + + cst_callbacks.begin_loop(epi_m, epi_n); + + if (is_producer_load_needed) { + // Wait for the producer load to fill smem + load_pipeline.consumer_wait(load_wait_state); + + if (is_C_load_needed) { + // Copy source tile from smem to register + copy(tiled_s2r, tSR_sC(_,_,_,load_wait_state.index()), tSR_rC); + } + } + + // First loop fusion callback entry point + cst_callbacks.previsit(epi_m, epi_n, load_wait_state.count(), is_producer_load_needed); + + if (is_producer_load_needed) { + if constexpr (not ReuseSmemC) { + // Let producer load warp know smem buffers are consumed and empty + cutlass::arch::fence_view_async_shared(); + load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + ++load_wait_state; + } + + int mma_m = epi_m; + int mma_n = (epi_n * size<1>(EpilogueTile{})) / mma_tile_n; + Tensor tRS_rAcc_frg_mn = tRS_rAcc_frg(_,mma_m,mma_n); + + // Vectorized fragment loop with visitor 
callback entry point + int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n); + int r2s_v = epi_n_in_mma * size(tRS_rD_frg); + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(tRS_rD_frg); ++epi_v) { + tRS_rD_frg(epi_v) = cst_callbacks.visit(tRS_rAcc_frg_mn(r2s_v + epi_v), epi_v, epi_m, epi_n); + } + // The latest we can delay the TMA store is right before the smem store of the next iteration + // since the current TMA store needs to be committed before we can acquire the next smem buffer + if constexpr (DelayTmaStore) { + // Issue TMA stores for the previous subtile + if (not is_first_iteration and subtile_idx == -1) { + tma_store_fn(epi_m_prev, epi_n_prev); + } + epi_m_prev = epi_m; + epi_n_prev = epi_n; + } + + // Smem reduction callback entry point using current store buffer for workspace + cst_callbacks.reduce(sD_epi(_,_,store_pipe_producer_state.index()), + synchronize, epi_m, epi_n, is_last_iteration, tRS_rD_frg); + + // Copy tile from register to regiser if needed + if constexpr (IsUseR2R) { + // retile source and destination for tiled_r2r + Tensor tRR_rD_src = thread_r2r.retile_S(tRS_rD); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) + Tensor tRR_rD_dst = thread_r2r.retile_D(tRS_rD); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) + + // Output needs register shuffling before copying to shared memory. + copy(tiled_r2r, tRR_rD_src, tRR_rD_dst); + } + + // Copy tile from register to smem + if constexpr (is_destination_supported) { + copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); + } + + // Post reduction, pre TMA store callback entry point + constexpr bool issue_smem_store = true; // No smem store predication + cst_callbacks.postreduce(epi_m, epi_n, store_pipe_producer_state.count(), issue_smem_store); + + if constexpr (not DelayTmaStore) { + // Issue TMA stores for this subtile + tma_store_fn(epi_m, epi_n); + } + + cst_callbacks.end_loop(epi_m, epi_n); + + } // for epi_m + } // for epi_n + + + if constexpr (DelayTmaStore) { + // Issue TMA stores for the last subtile + tma_store_fn(epi_m_prev, epi_n_prev); + } + + // Post-loop fusion callback entry point + cst_callbacks.end(); + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + + CUTLASS_DEVICE auto + store_tail( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state) { + // wait for all TMA stores to complete + store_pipeline.producer_tail(store_pipe_producer_state); + // reset store counter + issued_stores = 0; + + if constexpr (ReuseSmemC) { + if (fusion_callbacks.is_producer_load_needed()) { + // Issue releases on up to StagesD-1 previously issued TMA stores + constexpr int release_stages = cute::min(StorePipeline::UnacquiredStages, get_load_pipe_increment(CtaTileMNK{})); + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < release_stages; ++stage) { + load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + } + } + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + + CUTLASS_DEVICE auto + store_init( + Params const& params, + TensorMapStorage& shared_tensormaps, + int32_t sm_count, + int32_t sm_idx, + int32_t warp_group_idx) { + int warp_idx_in_warp_group = canonical_warp_idx_sync() % NumWarpsPerWarpGroup; + // Since only one warp issues TMA store, we only need that one warp to initialize tensormaps + if (warp_idx_in_warp_group == 0) { + // Initialize tma + constexpr bool IsLoad = false; + auto store_tensormaps = 
tensormaps_init(params, shared_tensormaps, sm_count, sm_idx, warp_group_idx); + return store_tensormaps; + } + TmaDescriptor* null_tma_desc = nullptr; + return cute::make_tuple(null_tma_desc); + } + + // + // Methods to perform different parts of TMA/Tensormap modifications + // + + template + CUTLASS_DEVICE auto + tensormaps_init( + Params const& params, + TensorMapStorage& shared_tensormaps, + int32_t sm_count, + int32_t sm_idx, + int32_t warp_group_idx) { + + constexpr uint32_t NumInputTensors = NumEpilogueWarpGroups + (cute::is_void_v ? 0 : 1); + Layout desc_layout = make_layout(make_shape(sm_count, Int{})); + + Tensor gmem_tensormap = make_tensor(params.tensormaps, desc_layout); // (SMs, NumInputTensors) + + if constexpr (IsLoad) { + if (is_source_supported) { + constexpr int C_tensormap_index = NumEpilogueWarpGroups; + Tensor pC_tensormap = make_tensor(params.tma_load_c.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sC_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_C), Int<1>{}, Int<1>{}); + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + copy(recast(pC_tensormap), recast(sC_tensormap)); + } + __syncwarp(); + return cute::make_tuple(&gmem_tensormap(sm_idx, C_tensormap_index)); + + } + TmaDescriptor* null_tma_desc = nullptr; + return cute::make_tuple(null_tma_desc); + } + else { + Tensor pD_tensormap = make_tensor(params.tma_store_d.get_tma_descriptor(), Int<1>{}, Int<1>{}); + Tensor sD_tensormap = make_tensor(make_smem_ptr(&shared_tensormaps.smem_tensormap_D[warp_group_idx]), Int<1>{}, Int<1>{}); + + if (cute::elect_one_sync()) { + // Bringing tensormaps from params to smem for modification later + copy(recast(pD_tensormap), recast(sD_tensormap)); + } + __syncwarp(); + return cute::make_tuple(&gmem_tensormap(sm_idx, warp_group_idx)); + } + } + + // Replace address for the global tensor (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_address( + TensorMapStorage& shared_tensormaps, + Params const& params, + int32_t next_batch, + int32_t warp_group_idx) { + // Replacing global_address for the next batch + if constexpr (IsLoad) { + if constexpr (is_source_supported) { + if (params.ptr_C != nullptr) { + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_C, + params.ptr_C[next_batch]); + } + } + } + else if constexpr (is_destination_supported) { + cute::tma_descriptor_replace_addr_in_shared_mem(shared_tensormaps.smem_tensormap_D[warp_group_idx], + params.ptr_D[next_batch]); + } + } + + // Replace dim and strides for the global tensor - used only for Grouped GEMM (to be done by single thread) + template + CUTLASS_DEVICE + void + tensormaps_replace_global_tensor_properties( + TensorMapStorage& shared_tensormaps, + Params const& params, + int32_t next_group, + ProblemShape_MNKL problem_shape_mnkl, + int32_t warp_group_idx) { + const uint32_t M = get<0>(problem_shape_mnkl); + const uint32_t N = get<1>(problem_shape_mnkl); + // Replace all dims for consistency + constexpr int MaxTensorRank = 5; + cute::array prob_shape = {1,1,1,1,1}; + cute::array prob_stride = {0,0,0,0,0}; + + if constexpr (IsLoad) { + if constexpr (is_source_supported) { + if (params.dC != nullptr) { + ElementC const* ptr_C = nullptr; + Tensor tensor_c = make_tensor(ptr_C, make_layout(make_shape(M,N,Int<1>{}), params.dC[next_group])); + + cute::detail::fill_tma_gmem_shape_stride(params.tma_load_c, tensor_c, + prob_shape, prob_stride); + // Convert strides to byte 
strides + for (uint64_t& stride : prob_stride) { + stride = (stride * sizeof_bits_v) / 8; + } + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_C, + prob_shape, + prob_stride); + } + } + } + else if constexpr (is_destination_supported) { + ElementD const* ptr_D = nullptr; + Tensor tensor_d = make_tensor(ptr_D, make_layout(make_shape(M,N,Int<1>{}), params.dD[next_group])); + + cute::detail::fill_tma_gmem_shape_stride(params.tma_store_d, tensor_d, + prob_shape, prob_stride); + // Convert strides to byte strides + for (uint64_t& stride : prob_stride) { + stride = (stride * sizeof_bits_v) / 8; + } + + cute::tma_descriptor_replace_dims_strides_in_shared_mem(shared_tensormaps.smem_tensormap_D[warp_group_idx], + prob_shape, + prob_stride); + } + } + + template + CUTLASS_DEVICE + void + tensormaps_perform_update( + TensorMapStorage& shared_tensormaps, + Params const& params, + cute::TmaDescriptor const* tensormap, + ProblemShape_MNKL problem_shape_mnkl, + int32_t next_batch, + int32_t warp_group_idx) { + + if (cute::elect_one_sync()) { + // Replacing global_address for the next batch + tensormaps_replace_global_address(shared_tensormaps, params, next_batch, warp_group_idx); + + if constexpr (IsGroupedGemmKernel) { + // Replacing global dims and strides for the next batch + tensormaps_replace_global_tensor_properties( + shared_tensormaps, params, next_batch, problem_shape_mnkl, warp_group_idx); + } + + } + } + + template + CUTLASS_DEVICE + void + tensormaps_cp_fence_release( + TensorMapStorage& shared_tensormaps, + cute::TmaDescriptor const* tensormap, + const int32_t warp_group_idx = 0) { + + // Entire warp must do this (ie its aligned) + if constexpr (IsLoad) { + if constexpr (is_source_supported) { + tma_descriptor_cp_fence_release(tensormap, shared_tensormaps.smem_tensormap_C); + } + } + else if constexpr (is_destination_supported) { + tma_descriptor_cp_fence_release(tensormap, shared_tensormaps.smem_tensormap_D[warp_group_idx]); + } + } + + template + CUTLASS_DEVICE + void + tensormaps_fence_acquire(cute::TmaDescriptor const* tensormap) { + if constexpr (IsLoad) { + if constexpr (is_source_supported) { + cute::tma_descriptor_fence_acquire(tensormap); + } + } + else { + cute::tma_descriptor_fence_acquire(tensormap); + } + } + +private: + Params const& params; + FusionCallbacks fusion_callbacks; + int issued_stores = 0; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp new file mode 100644 index 0000000000..b3c7bf387d --- /dev/null +++ b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized.hpp @@ -0,0 +1,918 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing elementwise operations used by epilogues. +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/collective/detail.hpp" +#include "cutlass/epilogue/thread/scale_type.h" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp" +#include "cutlass/detail/collective.hpp" +#include "cutlass/detail/layout.hpp" +#include "cutlass/trace.h" + +#include "cute/tensor.hpp" +#include "cutlass/cuda_host_adapter.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_, + class CtaTileMNK_, // (CTA_M,CTA_N,CTA_K) + class EpilogueTile_, // (EPI_TILE_M,EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class FusionCallbacks_, + class CopyOpG2S_, + class SmemLayoutAtomC_, + class CopyOpS2R_, + class CopyOpS2G_, + class SmemLayoutAtomD_, + class CopyOpR2S_, + class CopyAtomC_, + class CopyOpR2R_ +> +class CollectiveEpilogue< + Sm90TmaWarpSpecialized, + CtaTileMNK_, + EpilogueTile_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpG2S_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpS2G_, + SmemLayoutAtomD_, + CopyOpR2S_, + CopyAtomC_, + CopyOpR2R_ +> { +public: + // + // Type Aliases + // + using DispatchPolicy = Sm90TmaWarpSpecialized; + using CtaTileMNK = CtaTileMNK_; + using EpilogueTile = EpilogueTile_; + using FusionCallbacks = FusionCallbacks_; + using ElementC = ElementC_; + using StrideC = StrideC_; + using ElementD = ElementD_; + using StrideD = StrideD_; + using CopyOpG2S = CopyOpG2S_; + using SmemLayoutAtomC = SmemLayoutAtomC_; + using CopyOpS2R = CopyOpS2R_; + using CopyOpS2G = CopyOpS2G_; + using SmemLayoutAtomD = SmemLayoutAtomD_; + using CopyOpR2S = CopyOpR2S_; + using CopyAtomC = CopyAtomC_; + using 
CopyOpR2R = CopyOpR2R_; + + using ThreadEpilogueOp = typename epilogue::fusion::FusionCallbacksTraits::Operation; + using GmemTiledCopyC = CopyOpG2S; + using GmemTiledCopyD = CopyOpS2G; + + static_assert(!is_layout::value && is_tuple::value, "EpilogueTile must be a cute::Tile or cute::Shape"); + static_assert(cute::rank(CtaTileMNK{}) == 3, "CtaTileMNK must be rank-3: [CTA_M, CTA_N, CTA_K]"); + static_assert(cute::rank(EpilogueTile{}) == 2, "EpilogueTile must be rank-2: [EPI_TILE_M, EPI_TILE_N]"); + static_assert(size<0>(CtaTileMNK{}) % size<0>(shape(EpilogueTile{})) == 0, "EPI_TILE_M must divide CTA_M"); + static_assert(size<1>(CtaTileMNK{}) % size<1>(shape(EpilogueTile{})) == 0, "EPI_TILE_N must divide CTA_N"); + static_assert(cute::rank(StrideC{}) == 3, "StrideC must be rank-3: [M, N, L]"); + static_assert(cute::rank(StrideD{}) == 3, "StrideD must be rank-3: [M, N, L]"); + +private: + constexpr static bool is_source_supported = not cute::is_void_v; + constexpr static bool is_destination_supported = not cute::is_void_v; + using NonVoidElementD = cute::conditional_t, ElementD>; + static_assert(not cute::is_void_v, "SmemElementD is void"); + using NonVoidElementC = cute::conditional_t; // prevents void ref breakages + + using TmaElementD = cute::conditional_t>, uint64_t, NonVoidElementD>; + using TmaElementC = cute::conditional_t>, uint64_t, NonVoidElementC>; + + using SmemElementC = typename cutlass::detail::get_unpacked_element_type::type; + using SmemElementD = typename cutlass::detail::get_unpacked_element_type::type; + + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static bool ReuseSmemC = ReuseSmemC_ and is_destination_supported; + constexpr static bool DelayTmaStore = DelayTmaStore_; + + constexpr static bool is_m_major_C = detail::is_m_major(); + constexpr static bool is_m_major_D = detail::is_m_major(); + + constexpr static bool is_im2col_C = cute::is_same_v; + constexpr static bool is_im2col_D = cute::is_same_v; + + // Check if register transformation is needed before copying register to shared memory. 
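+  // When CopyOpR2R is non-void, accumulator fragments are shuffled register-to-register before
+  // the register-to-smem store so their layout matches the smem store pattern (the same
+  // tiled_r2r path used by the pointer-array epilogue in this change); when it is void,
+  // fragments are copied to smem directly.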
+ constexpr static bool IsUseR2R = !cute::is_void_v; + + using SmemLayoutC = decltype(tile_to_shape( + SmemLayoutAtomC{}, + make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{} )); + using SmemLayoutD = decltype(tile_to_shape( + SmemLayoutAtomD{}, + make_shape(size<0>(EpilogueTile{}), size<1>(EpilogueTile{}), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{} )); + + constexpr static bool support_smem_reuse = is_source_supported && is_destination_supported && StagesD <= StagesC + && cosize(take<0,2>(SmemLayoutC{})) == cosize(take<0,2>(SmemLayoutD{})); + static_assert(not (ReuseSmemC && not support_smem_reuse), "Smem reuse requirements not met"); + + constexpr static size_t SmemAlignmentD = cutlass::detail::alignment_for_swizzle(SmemLayoutD{}); + constexpr static size_t SmemAlignmentC = cutlass::detail::alignment_for_swizzle(SmemLayoutC{}); + constexpr static size_t MaxSmemAlignment = cute::max(SmemAlignmentC, SmemAlignmentD); + + using SmemArrayTypeC = cute::ArrayEngine>; + using SmemArrayTypeD = cute::ArrayEngine>; + + using EmptyType = cute::tuple<>; + using SmemCStorage = cute::conditional_t; + using SmemDStorage = cute::conditional_t; + + struct CollectiveStorageWithC { + alignas(SmemAlignmentC) ArrayEngine> smem_C; + alignas(SmemAlignmentD) ArrayEngine> smem_D; + }; + + union CollectiveStorageWithoutC { + cute::array smem_C; + alignas(SmemAlignmentD) ArrayEngine> smem_D; + }; + + union CollectiveStorageReuseC { + alignas(MaxSmemAlignment) ArrayEngine> smem_C; + alignas(MaxSmemAlignment) ArrayEngine> smem_D; + }; + +public: + // TMA pipeline for loading C + using LoadPipeline = cutlass::PipelineTransactionAsync; + using LoadPipelineState = cutlass::PipelineState; + constexpr static uint32_t TmaTransactionBytes = + (size(take<0,2>(SmemLayoutC{})) * static_cast(sizeof_bits::value)) / 8; + constexpr static bool RequiresTransactionBytes = true; + + // TMA pipeline for storing D + using StorePipeline = cute::conditional_t, + cutlass::PipelineTmaStore>; + using StorePipelineState = cutlass::PipelineState; + + struct SharedStorage { + struct TensorStorage { + using CollectiveStorage = cute::conditional_t>; + CollectiveStorage collective; + + using FusionStorage = typename FusionCallbacks::SharedStorage; + FusionStorage thread; + } tensors; + + using PipelineStorage = typename LoadPipeline::SharedStorage; + PipelineStorage pipeline; + }; + using TensorStorage = typename SharedStorage::TensorStorage; + using PipelineStorage = typename SharedStorage::PipelineStorage; + + // Host side epilogue arguments + struct Arguments { + typename FusionCallbacks::Arguments thread{}; + ElementC const* ptr_C; + StrideC dC; + ElementD const* ptr_D; + StrideD dD; + }; + + // Device side epilogue params + struct Params { + using TMA_C = decltype(make_tma_copy( + CopyOpG2S{}, + make_tensor(make_gmem_ptr(nullptr), + repeat_like(StrideC{}, int32_t(0)), StrideC{}), + take<0,2>(SmemLayoutC{}), + EpilogueTile{}, + _1{})); + using TMA_D = decltype(make_tma_copy( + CopyOpS2G{}, + make_tensor(make_gmem_ptr(nullptr), + repeat_like(StrideD{}, int32_t(0)), StrideD{}), + take<0,2>(SmemLayoutD{}), + EpilogueTile{}, + _1{})); + + typename FusionCallbacks::Params thread{}; + TMA_C tma_load_c; + TMA_D tma_store_d; + uint32_t tma_transaction_bytes = TmaTransactionBytes; + }; + + // + // Methods + // + + template + static constexpr Params + to_underlying_arguments( + ProblemShape const& problem_shape, + Arguments const& args, + [[maybe_unused]] void* workspace) { + // 
Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_MNKL; + + uint32_t transaction_bytes = TmaTransactionBytes; + typename Params::TMA_C tma_load_c{}; + if constexpr (is_source_supported) { + Tensor tensor_c = make_tensor(make_gmem_ptr(args.ptr_C), make_layout(make_shape(M,N,L), args.dC)); + tma_load_c = make_tma_copy_C_sm90( + CopyOpG2S{}, + tensor_c, + take<0,2>(SmemLayoutC{}), + EpilogueTile{}); + } + + typename Params::TMA_D tma_store_d; + if constexpr (is_destination_supported) { + Tensor tensor_d = make_tensor(make_gmem_ptr(args.ptr_D), make_layout(make_shape(M,N,L), args.dD)); + tma_store_d = make_tma_copy_C_sm90( + CopyOpS2G{}, + tensor_d, + take<0,2>(SmemLayoutD{}), + EpilogueTile{}); + } + + return { + FusionCallbacks::to_underlying_arguments(problem_shape, args.thread, workspace), + tma_load_c, + tma_store_d, + transaction_bytes + }; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return FusionCallbacks::get_workspace_size(problem_shape, args.thread); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return FusionCallbacks::initialize_workspace(problem_shape, args.thread, workspace, stream, cuda_adapter); + } + + template + static bool + can_implement( + ProblemShape const& problem_shape, + [[maybe_unused]] Arguments const& args) { + auto problem_shape_MNKL = append<4>(problem_shape, 1); + auto [M,N,K,L] = problem_shape_MNKL; + auto shape = cute::make_shape(M,N,L); + + bool implementable = true; + if constexpr (is_destination_supported) { + constexpr int tma_alignment_bits_D = cutlass::detail::get_output_alignment_bits(); + constexpr int min_tma_aligned_elements_D = tma_alignment_bits_D / cutlass::sizeof_bits::value; + if constexpr (cute::is_same_v) { // ignore L stride for implicit gemm + implementable = cutlass::detail::check_alignment(take<0,2>(shape), take<0,2>(StrideD{})); + } + else { + implementable = cutlass::detail::check_alignment(shape, StrideD{}); + } + } + + if constexpr (not cute::is_void_v) { + constexpr int tma_alignment_bits_C = cutlass::detail::get_input_alignment_bits(); + constexpr int min_tma_aligned_elements_C = tma_alignment_bits_C / cutlass::sizeof_bits::value; + if constexpr (cute::is_same_v) { // ignore L stride for implicit gemm + implementable = implementable && cutlass::detail::check_alignment(take<0,2>(shape), take<0,2>(StrideC{})); + } + else { + implementable = implementable && cutlass::detail::check_alignment(shape, StrideC{}); + } + } + + if (!implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum alignment requirements for TMA.\n"); + } + + bool fusion_implementable = FusionCallbacks::can_implement(problem_shape, args.thread); + + if (!fusion_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Problem Size doesn't meet the minimum requirements for FusionCallbacks.\n"); + } + + bool beta_implementable = true; + + if constexpr (cute::is_void_v) { + if constexpr (detail::has_beta::value) { + beta_implementable = args.thread.beta == 0.0; + } + if constexpr (detail::has_beta_ptr::value) { + beta_implementable = beta_implementable && args.thread.beta_ptr == nullptr; + } + } + + if (!beta_implementable) { + CUTLASS_TRACE_HOST(" CAN IMPLEMENT: Beta/beta pointer was 
set, but epilogue is sourceless (void-C).\n"); + } + + return implementable && fusion_implementable && beta_implementable; + } + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_load_pipe_increment(TileShapeMNK tile_shape_MNK) { + // Compute number of epilogue subtiles + return size<1>(zipped_divide(make_layout(take<0,2>(tile_shape_MNK)), EpilogueTile{})); + } + + template + CUTLASS_HOST_DEVICE + static constexpr int + get_store_pipe_increment(TileShapeMNK tile_shape_MNK) { + return get_load_pipe_increment(tile_shape_MNK); + } + + /// Issue Tma Descriptor Prefetch -- ideally from a single thread for best performance + CUTLASS_DEVICE + static void + prefetch_tma_descriptors(Params const& epilogue_params) { + if constexpr (is_source_supported) { + cute::prefetch_tma_descriptor(epilogue_params.tma_load_c.get_tma_descriptor()); + } + if constexpr (is_destination_supported) { + cute::prefetch_tma_descriptor(epilogue_params.tma_store_d.get_tma_descriptor()); + } + } + + CUTLASS_HOST_DEVICE + CollectiveEpilogue(Params const& params_, TensorStorage& shared_tensors) + : params(params_), fusion_callbacks(params_.thread, shared_tensors.thread) {} + + CUTLASS_DEVICE + bool + is_producer_load_needed() const { + return fusion_callbacks.is_producer_load_needed(); + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class TiledMma + > + CUTLASS_DEVICE auto + load( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + TiledMma tiled_mma, + int thread_idx, + TensorStorage& shared_tensors, + int subtile_idx=-1) { + using namespace cute; + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + + // The tma tensor C under im2col mode only has two modes (M, N) which + // should be local tiled with only (m_coord, n_coord). 
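// That is, for an im2col-strided C (implicit-GEMM convolution) the L mode is already folded into the
// TMA im2col descriptor, so the CTA tile coordinate drops l_coord; for an ordinary GEMM the full
// (m, n, l) coordinate is used. conditional_return below selects between the two at compile time,
// roughly:
//
//   coord_shape = is_im2col_C ? (m_coord, n_coord) : (m_coord, n_coord, l_coord)   // illustrative only
//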
+ auto coord_shape = conditional_return( + make_coord(m_coord, n_coord), + make_coord(m_coord, n_coord, l_coord)); + + // Represent the full source tensor, slice to get the tile this CTA is currently responsible for + Tensor mC_mn = params.tma_load_c.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor mC = coalesce(mC_mn, take<0,2>(CtaTileMNK{})); + Tensor gC = local_tile(mC, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N) + + // Apply epilogue subtile, get matching smem tensor + auto ptr_sC = shared_tensors.collective.smem_C.begin(); + Tensor gC_epi = flat_divide(gC, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor sC_epi = make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{}); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + + // Prepare the thread(b)lock's (G)mem to (S)mem TMA tiled copy (bGS_) + ThrCopy thrblk_g2s = params.tma_load_c.get_slice(Int<0>{}); + Tensor bGS_gC = thrblk_g2s.partition_S(gC_epi); // (G2S,G2S_M,G2S_N,EPI_M,EPI_N) + Tensor bGS_sC = thrblk_g2s.partition_D(sC_epi); // (G2S,G2S_M,G2S_N,PIPE_C) + + // Get the fusion callbacks for the producer load warp + auto pld_args = cutlass::epilogue::fusion::detail::ProducerLoadArgs( + problem_shape_mnkl, + CtaTileMNK{}, + tile_coord_mnkl, + tiled_mma, + EpilogueTile{}, + thread_idx + ); + auto pld_callbacks = fusion_callbacks.get_producer_load_callbacks(pld_args); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + // Predication for TMA load (one thread issues TMA load) + bool issue_tma_load = cute::elect_one_sync(); + + // Pre-loop fusion callback entry point + pld_callbacks.begin(); + + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < size<3>(gC_epi); ++epi_n) { + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < size<2>(gC_epi); ++epi_m) { + if (subtile_idx != -1 && (epi_n * static_cast(size<2>(gC_epi)) + epi_m) != subtile_idx) { + continue; + } + // Acquire the lock for this stage + constexpr uint16_t mcast_mask = 0; + uint64_t* tma_barrier = load_pipeline.producer_get_barrier(load_pipe_producer_state); + load_pipeline.producer_acquire(load_pipe_producer_state); + + // Loop fusion callback entry point + pld_callbacks.step(tma_barrier, epi_m, epi_n, load_pipe_producer_state.count(), issue_tma_load); + + // Execute the TMA load for C if needed + if (issue_tma_load && is_C_load_needed) { + copy(params.tma_load_c.with(*tma_barrier, mcast_mask), + bGS_gC(_,_,_,epi_m,epi_n), bGS_sC(_,_,_,load_pipe_producer_state.index())); + load_pipeline.producer_expect_transaction(load_pipe_producer_state); + } + + // Commit TMA loads for this stage and release the lock + load_pipeline.producer_commit(load_pipe_producer_state); + ++load_pipe_producer_state; + } + } + + // Post-loop fusion callback entry point + pld_callbacks.end(); + + return load_pipe_producer_state; + } + + CUTLASS_DEVICE auto + load_tail( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_producer_state) { + bool issue_tma_load = cute::elect_one_sync(); + if (issue_tma_load) { + load_pipeline.producer_tail(load_pipe_producer_state); + } + + return load_pipe_producer_state; + } + + template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class AccEngine, class AccLayout, + class TiledMma + > + CUTLASS_DEVICE auto + store( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state, + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_MNK, + TileCoordMNKL tile_coord_mnkl, + 
cute::Tensor accumulators, + TiledMma tiled_mma, + int thread_idx, + TensorStorage& shared_tensors, + int subtile_idx=-1) { + using namespace cute; + using ElementAccumulator = typename AccEngine::value_type; + using ElementCompute_ = typename epilogue::fusion::FusionCallbacksTraits::ElementCompute; + using ElementCompute = cute::conditional_t,ElementAccumulator,ElementCompute_>; + + static_assert(is_rmem::value, "Accumulator must be RF resident."); + static_assert(rank(AccLayout{}) == 3, "Accumulator must be MMA-partitioned: (MMA,MMA_M,MMA_N)"); + static_assert(rank(ProblemShapeMNKL{}) == 4, "ProblemShapeMNKL must be rank 4"); + static_assert(is_static::value, "TileShapeMNK must be static"); + static_assert(rank(TileShapeMNK{}) == 3, "TileShapeMNK must be rank 3"); + static_assert(rank(TileCoordMNKL{}) == 4, "TileCoordMNKL must be rank 4"); + + // Indexing variables + auto [M, N, K, L] = problem_shape_mnkl; + auto [m_coord, n_coord, k_coord, l_coord] = tile_coord_mnkl; + + // The tma tensor D under im2col mode only has two modes (M, N) which + // should be local tiled with only (m_coord, n_coord). + auto coord_shape = conditional_return( + make_coord(m_coord, n_coord), + make_coord(m_coord, n_coord, l_coord)); + + // Represent the full output tensor, slice to get the tile this CTA is responsible for + Tensor mD_mn = params.tma_store_d.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor mD = coalesce(mD_mn, take<0,2>(CtaTileMNK{})); + Tensor gD = local_tile(mD, take<0,2>(CtaTileMNK{}), coord_shape); // (CTA_M,CTA_N) + + // Apply epilogue subtiling + Tensor gD_epi = flat_divide(gD, EpilogueTile{}); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + // Construct the corresponding pipelined smem tensors + auto ptr_sC = shared_tensors.collective.smem_C.begin(); + auto ptr_sD = shared_tensors.collective.smem_D.begin(); + Tensor sC_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sC), SmemLayoutC{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_C) + Tensor sD_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(ptr_sD), SmemLayoutD{})); // (EPI_TILE_M,EPI_TILE_N,PIPE_D) + + TiledCopy tiled_copy_C_atom = make_tiled_copy_C_atom(CopyAtomC{}, tiled_mma); + + // (t)hread-partition for (r)egister to (r)egister copy (tRR_) + TiledCopy tiled_r2r = [&]() { + if constexpr (IsUseR2R) { + return make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + } + else { + return make_tiled_copy_S(Copy_Atom, + ElementCompute>{}, tiled_copy_C_atom); + } + }(); + ThrCopy thread_r2r = tiled_r2r.get_slice(thread_idx); + + // (t)hread-partition for (r)egister to (s)mem copy (tRS_) + TiledCopy tiled_r2s = [&]() { + if constexpr (IsUseR2R) { + return make_tiled_copy_D(Copy_Atom{}, tiled_r2r); + } + else { + return make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + } + }(); + ThrCopy thread_r2s = tiled_r2s.get_slice(thread_idx); + Tensor tRS_rAcc = thread_r2s.retile_S(accumulators); // ((R2S,R2S_V),MMA_M,MMA_N) + Tensor tRS_sD = thread_r2s.partition_D(sD_epi); // (R2S,R2S_M,R2S_N,PIPE_D) + + auto mma_tile_m = size<0>(TileShapeMNK{}) / size<1>(tRS_rAcc); + auto mma_tile_n = size<1>(TileShapeMNK{}) / size<2>(tRS_rAcc); + auto epi_tile_m = size<0>(EpilogueTile{}); + auto epi_tile_n = size<1>(EpilogueTile{}); + + // Allocate D registers + Layout tRS_rD_layout = make_layout(take<0,3>(shape(thread_r2s.partition_S(sD_epi)))); + Tensor tRS_rD = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) + + // Vectorized fragment view + constexpr int FragmentSize = DispatchPolicy::FragmentSize; + 
Tensor tRS_rAcc_frg = recast>(tRS_rAcc); + Tensor tRS_rD_frg = recast>(tRS_rD); + CUTE_STATIC_ASSERT(size<0>(tRS_rAcc) % FragmentSize == 0, "Fragment size does not vectorize properly"); + + // (t)hread-partition for (s)mem to (r)egister copy (tSR_) + TiledCopy tiled_s2r = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + ThrCopy thread_s2r = tiled_s2r.get_slice(thread_idx); + Tensor tSR_sC = thread_s2r.partition_S(sC_epi); // (S2R,S2R_M,S2R_N,PIPE_C) + Layout tSR_rC_layout = thread_s2r.retile_D(tRS_rD).layout(); // (S2R,S2R_M,S2R_N) + + // Allocate C registers + // If C smem load is a non-vectorized dst(i) = src(i) then we can allocate C registers directly in the compute type + // to eliminate some redundant pack+unpack instruction sequences for sub-word types + constexpr bool IsDirectS2R = cute::is_same_v> + && decltype(max_common_vector(tSR_rC_layout, tSR_sC.layout()))::value <= 1; + using RegisterElementC = cute::conditional_t; + Tensor tRS_rC = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) + Tensor tSR_rC = thread_s2r.retile_D(tRS_rC); // (S2R,S2R_M,S2R_N) + + // thread(b)lock-partition for (s)mem to (g)mem copy (bSG_) + ThrCopy thrblk_s2g = params.tma_store_d.get_slice(Int<0>{}); + Tensor bSG_sD = thrblk_s2g.partition_S(sD_epi); // (S2G,S2G_M,S2G_N,PIPE_D) + Tensor bSG_gD = thrblk_s2g.partition_D(gD_epi); // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) + + // OOB predication for tile quantization "residue" + // Absolute coordinate tensors (dynamic) + Tensor mD_crd = make_identity_tensor(make_shape(M,N)); // (M,N) + Tensor cD_mn = local_tile(mD_crd, take<0,2>(CtaTileMNK{}), make_coord(m_coord, n_coord)); // (CTA_M,CTA_N) + Tensor tRS_cD_mn = [&]() { + if constexpr (IsUseR2R) { + // (t)hread-partition for ConsumerStoreCallbacks. + TiledCopy tiled_cst = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + ThrCopy thread_cst = tiled_cst.get_slice(thread_idx); + + return thread_cst.partition_S(flat_divide(cD_mn, EpilogueTile{})); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) + } + else { + return thread_r2s.partition_S(flat_divide(cD_mn, EpilogueTile{})); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) + } + }(); + // Relative coordinate tensors (static) + Tensor cD = make_counting_tensor(cD_mn.layout()); // (CTA_M,CTA_N) + Tensor tRS_cD = make_counting_tensor(tRS_cD_mn.layout()); // (R2S,R2S_M,R2S_N,EPI_M,EPI_N) + // Subtract the global "bottom right" corner from the local "top left" corner to get the max relative coordinate + auto residue_cD = make_coord(M,N) - cD_mn(_0{}); // (m,n) + auto residue_tRS_cD = make_coord(M,N) - tRS_cD_mn(_0{}); // (m,n) + + CUTE_STATIC_ASSERT(epi_tile_m % mma_tile_m == 0, "MMA_TILE_M must divide EPI_TILE_M"); + + CUTE_STATIC_ASSERT(mma_tile_n % epi_tile_n == 0, "EPI_TILE_N must divide MMA_TILE_N"); + // Get TiledCopy for partition reference when consumer store. 
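// This reference copy is handed to the fusion callbacks via ConsumerStoreArgs below, so any auxiliary
// tensors the fusion reads or writes (per-row/per-column bias, aux outputs, etc.) are partitioned with
// the same per-thread layout as the register-to-smem path, keeping thread-local coordinates consistent
// across the visitor callbacks.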
+ TiledCopy tiled_copy_partition_ref = make_tiled_copy_S(Copy_Atom{}, tiled_copy_C_atom); + // Get the fusion callbacks for the consumer store warps + constexpr bool RefSrc = true; // Register tensors reference tiled copy src layout + auto cst_args = cutlass::epilogue::fusion::detail::ConsumerStoreArgs( + problem_shape_mnkl, + CtaTileMNK{}, + tile_coord_mnkl, + tiled_mma, + EpilogueTile{}, + tiled_copy_partition_ref, + cD, + residue_cD, + tRS_cD, + residue_tRS_cD, + tRS_rC, + thread_idx + ); + auto cst_callbacks = fusion_callbacks.template get_consumer_store_callbacks(cst_args); + bool is_producer_load_needed = fusion_callbacks.is_producer_load_needed(); + bool is_C_load_needed = is_source_supported && fusion_callbacks.is_C_load_needed(); + + using FragmentVisit = decltype(cst_callbacks.visit(tRS_rAcc_frg(0), 0, 0, 0)); + constexpr bool IsDirectR2S = cute::is_same_v>; + using RegisterElementD = cute::conditional_t; + Tensor tRS_rCompute = make_tensor(tRS_rD_layout); // (R2S,R2S_M,R2S_N) + Tensor tRS_rCompute_frg = recast>(tRS_rCompute); + + // Thread synchronizer for previously issued waits or fences + // to ensure visibility of smem reads/writes to threads or TMA unit + auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(size(TiledMma{}), cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + + // Predication for TMA store (one warp issues TMA store) + bool issue_tma_store = (thread_idx / NumThreadsPerWarp) == 0; + + // In the reuse smem configuration we have StagesC smem buffers and at most StagesD committed TMA stores in flight. + // The TMA store pipeline producer acquire returns when at most StagesD-1 committed stores are in-flight, so we can + // only guarantee store completion after StagesD iterations, then we can begin issuing releases on the smem buffer locks. + // store_pipe_producer_state tracks the acquire and load_pipe_consumer_state tracks the release, in circular buffer fashion. 
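// As a concrete illustration (stage counts chosen only for the example): with StagesC = StagesD = 4,
// the producer_acquire issued before writing buffer k returns once at most three earlier stores are
// still in flight, which proves the store that used buffer k-4 has drained; only then may that stage's
// C-load lock be released. Releases therefore trail acquires by StagesD iterations, and the wait state
// below tracks store_pipe_producer_state (with flipped phase) rather than the plain consumer state.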
+ LoadPipelineState load_wait_state = load_pipe_consumer_state; + if constexpr (ReuseSmemC) { + load_wait_state = store_pipe_producer_state; + load_wait_state.phase_ ^= 1; + } + + // We can delay issue of TMA store by one iteration to achieve better interleaving of non-TMA instructions + // Sync requirements of smem reuse may preclude this optimization + // Delayed stores cause delayed stage releases which causes deadlock when StagesC == StagesD + [[maybe_unused]] int epi_m_prev = 0; + [[maybe_unused]] int epi_n_prev = 0; + static_assert(not (DelayTmaStore and ReuseSmemC and StagesC <= StagesD), "This TMA epilogue configuration will deadlock"); + + // The TMA store sequence for one subtile iteration + auto tma_store_fn = [&] (int epi_m, int epi_n) { + // Write the tile from smem to gmem with TMA + cutlass::arch::fence_view_async_shared(); // ensure smem writes are visible to TMA + synchronize(); // ensure all threads have issued their async fence + if constexpr (is_destination_supported) { + if (issue_tma_store) { + copy(params.tma_store_d, bSG_sD(_,_,_,store_pipe_producer_state.index()), bSG_gD(_,_,_,epi_m,epi_n)); + } + } + + // Post async fence, pre TMA commit callback entry point + cst_callbacks.tma_store(epi_m, epi_n, store_pipe_producer_state.count(), issue_tma_store); + + // Commit the TMA stores for this stage + if (issue_tma_store) { + store_pipeline.producer_commit(store_pipe_producer_state); + } + ++store_pipe_producer_state; + ++issued_stores; + + // Wait for the next smem buffer to be available + if (issue_tma_store) { + store_pipeline.producer_acquire(store_pipe_producer_state); + } + synchronize(); + + if constexpr (ReuseSmemC) { + // producer_acquire returns when at most StagesD-1 committed stores are pending + bool store_finished = issued_stores > StorePipeline::UnacquiredStages; + // Let dma warp know earliest smem buffer is consumed and empty after StagesD producer commits + if (store_finished) { + if (is_producer_load_needed) { + load_pipeline.consumer_release(load_pipe_consumer_state); + } + ++load_pipe_consumer_state; + } + } + }; + + // + // BEGIN EPILOGUE + // + + // Pre-loop fusion callback entry point + cst_callbacks.begin(); + + // For each output tile + CUTLASS_PRAGMA_UNROLL + for (int epi_n = 0; epi_n < size<3>(gD_epi); ++epi_n) { + CUTLASS_PRAGMA_UNROLL + for (int epi_m = 0; epi_m < size<2>(gD_epi); ++epi_m) { + [[maybe_unused]] bool is_first_iteration = epi_m == 0 && epi_n == 0; + bool is_last_iteration = epi_m == size<2>(gD_epi)-1 && epi_n == size<3>(gD_epi)-1; + + if (subtile_idx != -1 && (epi_n * static_cast(size<2>(gD_epi)) + epi_m) != subtile_idx) { + continue; + } + + cst_callbacks.begin_loop(epi_m, epi_n); + + if (is_producer_load_needed) { + // Wait for the producer load to fill smem + load_pipeline.consumer_wait(load_wait_state); + + if (is_C_load_needed) { + // Copy source tile from smem to register + copy(tiled_s2r, tSR_sC(_,_,_,load_wait_state.index()), tSR_rC); + } + } + + // First loop fusion callback entry point + cst_callbacks.previsit(epi_m, epi_n, load_wait_state.count(), is_producer_load_needed); + + if (is_producer_load_needed) { + if constexpr (not ReuseSmemC) { + // Let producer load warp know smem buffers are consumed and empty + cutlass::arch::fence_view_async_shared(); + load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + ++load_wait_state; + } + + int mma_m = epi_m; + int mma_n = (epi_n * size<1>(EpilogueTile{})) / mma_tile_n; + Tensor tRS_rAcc_frg_mn = tRS_rAcc_frg(_,mma_m,mma_n); + + // 
Vectorized fragment loop with visitor callback entry point + int epi_n_in_mma = epi_n % (mma_tile_n / epi_tile_n); + int r2s_v = epi_n_in_mma * size(tRS_rCompute_frg); + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(tRS_rCompute_frg); ++epi_v) { + tRS_rCompute_frg(epi_v) = cst_callbacks.visit(tRS_rAcc_frg_mn(r2s_v + epi_v), epi_v, epi_m, epi_n); + } + // The latest we can delay the TMA store is right before the smem store of the next iteration + // since the current TMA store needs to be committed before we can acquire the next smem buffer + if constexpr (DelayTmaStore) { + // Issue TMA stores for the previous subtile + if (not is_first_iteration and subtile_idx == -1) { + tma_store_fn(epi_m_prev, epi_n_prev); + } + epi_m_prev = epi_m; + epi_n_prev = epi_n; + } + + // Smem reduction callback entry point using current store buffer for workspace + cst_callbacks.reduce(sD_epi(_,_,store_pipe_producer_state.index()), + synchronize, epi_m, epi_n, is_last_iteration, tRS_rCompute_frg); + + // Copy tile from register to regiser if needed + if constexpr (IsUseR2R) { + // retile source and destination for tiled_r2r + Tensor tRR_rD_src = thread_r2r.retile_S(tRS_rCompute); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) + Tensor tRR_rD_dst = thread_r2r.retile_D(tRS_rCompute); // (R2R,R2R_M,R2R_N,EPI_M,EPI_N) + + // Output register transformation before copying to shared memory. + copy(tiled_r2r, tRR_rD_src, tRR_rD_dst); + } + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tRS_rD_frg); ++i) { + tRS_rD_frg(i) = cutlass::NumericArrayConverter{}(tRS_rCompute_frg(i)); + } + + // Copy tile from register to smem + if constexpr (is_destination_supported) { + copy(tiled_r2s, tRS_rD, tRS_sD(_,_,_,store_pipe_producer_state.index())); + } + + // Post reduction, pre TMA store callback entry point + constexpr bool issue_smem_store = true; // No smem store predication + cst_callbacks.postreduce(epi_m, epi_n, store_pipe_producer_state.count(), issue_smem_store); + + if constexpr (not DelayTmaStore) { + // Issue TMA stores for this subtile + tma_store_fn(epi_m, epi_n); + } + + cst_callbacks.end_loop(epi_m, epi_n); + + } // for epi_m + } // for epi_n + + if constexpr (DelayTmaStore) { + // Issue TMA stores for the last subtile + tma_store_fn(epi_m_prev, epi_n_prev); + } + + // Post-loop fusion callback entry point + cst_callbacks.end(); + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + + CUTLASS_DEVICE auto + store_tail( + LoadPipeline load_pipeline, + LoadPipelineState load_pipe_consumer_state, + StorePipeline store_pipeline, + StorePipelineState store_pipe_producer_state) { + // wait for all TMA stores to complete + store_pipeline.producer_tail(store_pipe_producer_state); + // reset store counter + issued_stores = 0; + + if constexpr (ReuseSmemC) { + if (fusion_callbacks.is_producer_load_needed()) { + // Issue releases on up to StagesD-1 previously issued TMA stores + constexpr int release_stages = cute::min(StorePipeline::UnacquiredStages, get_load_pipe_increment(CtaTileMNK{})); + CUTLASS_PRAGMA_UNROLL + for (int stage = 0; stage < release_stages; ++stage) { + load_pipeline.consumer_release(load_pipe_consumer_state); + ++load_pipe_consumer_state; + } + } + } + + return cute::make_tuple(load_pipe_consumer_state, store_pipe_producer_state); + } + +private: + Params const& params; + FusionCallbacks fusion_callbacks; + int issued_stores = 0; +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // 
namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp new file mode 100644 index 0000000000..9749040081 --- /dev/null +++ b/include/cutlass/epilogue/collective/sm90_epilogue_tma_warpspecialized_bias_elementwise.hpp @@ -0,0 +1,164 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +/*! \file + \brief Functor performing pipelined epilogues with bias add and elementwise activation functions. + This collective is now DEPRECATED, will be removed in the next release. Use EVT instead. 
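    (See the TmaWarpSpecialized schedule combined with fusion::LinCombPerRowBiasEltActAux for the
    EVT-based equivalent of this bias + elementwise epilogue.)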
+*/ + +#pragma once + +#include "sm90_epilogue_tma_warpspecialized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass { +namespace epilogue { +namespace collective { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int StagesC_, + int StagesD_, + int FragmentSize_, + class BlockTileShape_, // (BLK_M,BLK_N,BLK_K) + class EpilogueTileShape_, // (EPI_TILE_M,EPI_TILE_N) + class ElementC_, + class StrideC_, + class ElementD_, + class StrideD_, + class FusionCallbacks_, + class CopyOpG2S_, + class SmemLayoutAtomC_, + class CopyOpS2R_, + class CopyOpS2G_, + class SmemLayoutAtomD_, + class CopyOpR2S_, + class CopyAtomC_, + class CopyOpR2R_ +> +class Sm90EpilogueTmaWarpSpecializedBiasElementwise + : public CollectiveEpilogue< + Sm90TmaWarpSpecialized, + BlockTileShape_, + EpilogueTileShape_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpG2S_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpS2G_, + SmemLayoutAtomD_, + CopyOpR2S_, + CopyAtomC_, + CopyOpR2R_ +> { +private: + using Impl = + CollectiveEpilogue< + Sm90TmaWarpSpecialized, + BlockTileShape_, + EpilogueTileShape_, + ElementC_, + StrideC_, + ElementD_, + StrideD_, + FusionCallbacks_, + CopyOpG2S_, + SmemLayoutAtomC_, + CopyOpS2R_, + CopyOpS2G_, + SmemLayoutAtomD_, + CopyOpR2S_, + CopyAtomC_, + CopyOpR2R_ + >; +public: + using DispatchPolicy = Sm90TmaWarpSpecializedBiasElementwise; + using ElementCompute = typename Impl::ThreadEpilogueOp::ElementCompute; + using ElementBias = typename Impl::ThreadEpilogueOp::ElementBias; + using ElementT = typename Impl::ThreadEpilogueOp::ElementAux; + + // Constructor inheritance + using Impl::Impl; + + // Host side epilogue arguments + struct [[deprecated("use Sm90TmaWarpSpecialized Arguments instead")]] + Arguments { + struct ThreadArgs { + ElementCompute alpha{1}; + ElementCompute beta{0}; + ElementCompute const *alpha_ptr{nullptr}; + ElementCompute const *beta_ptr{nullptr}; + } thread; + ElementC_ const* ptr_C{nullptr}; + StrideC_ dC{}; + ElementD_* ptr_D{nullptr}; + StrideD_ dD{}; + ElementBias const* ptr_Bias{nullptr}; + ElementT* ptr_T{nullptr}; + + CUTLASS_HOST_DEVICE + operator typename Impl::Arguments() const { + typename Impl::Arguments arguments; + arguments.thread.alpha = thread.alpha; + arguments.thread.beta = thread.beta; + arguments.thread.alpha_ptr = thread.alpha_ptr; + arguments.thread.beta_ptr = thread.beta_ptr; + if constexpr (not cute::is_void_v) { + arguments.thread.bias_ptr = ptr_Bias; + } + if constexpr (not cute::is_void_v) { + arguments.thread.aux_ptr = ptr_T; + arguments.thread.dAux = dD; + } + arguments.ptr_C = ptr_C; + arguments.dC = dC; + arguments.ptr_D = ptr_D; + arguments.dD = dD; + + return arguments; + } + }; + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace collective +} // namespace epilogue +} // namespace cutlass + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/dispatch_policy.hpp b/include/cutlass/epilogue/dispatch_policy.hpp new file mode 100644 index 0000000000..a5f47f0832 --- /dev/null +++ b/include/cutlass/epilogue/dispatch_policy.hpp @@ -0,0 +1,196 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. 
+ * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ +#pragma once + +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/scale_type.h" + +////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue { + +////////////////////////////////////////////////////////////////////////////// + +////////////////////////////////////////////////////////////////////////////// +// +// Builder Epilogue Schedules +// +////////////////////////////////////////////////////////////////////////////// + +struct PtrArrayDefault {}; +struct EpilogueSimtVectorized {}; +struct EpiloguePtrArraySimtVectorized {}; +struct NoSmemWarpSpecialized {}; +struct PtrArrayNoSmemWarpSpecialized {}; +struct PtrArrayNoSmemWarpSpecializedTransposed {}; +struct PtrArrayPlanarComplexNoSmemWarpSpecialized {}; +struct TmaWarpSpecialized {}; +struct TmaWarpSpecializedCooperative {}; +struct PtrArrayTmaWarpSpecializedCooperative { + static constexpr int NumEpilogueWarpGroups = 2; +}; + +// Standard warp specialized epilogue +struct PtrArrayTmaWarpSpecialized { + static constexpr int NumEpilogueWarpGroups = 1; +}; + +// Pingpong kernel epilogue +struct PtrArrayTmaWarpSpecializedPingpong { + static constexpr int NumEpilogueWarpGroups = 2; +}; + +// DEPRECATED schedules, will be removed in next release +struct TmaWarpSpecializedElementwiseBase : public TmaWarpSpecialized {}; +struct TmaWarpSpecializedCooperativeElementwiseBase : public TmaWarpSpecializedCooperative {}; +template < + template class ActivationFunctor_, + thread::ScaleType::Kind Scale_ = thread::ScaleType::Default, + FloatRoundStyle Round_ = FloatRoundStyle::round_to_nearest +> +struct [[deprecated("Use TmaWarpSpecialized with fusion::LinCombEltAct instead")]] +TmaWarpSpecializedElementwise : public TmaWarpSpecializedElementwiseBase { + template + using ActivationFunctor = ActivationFunctor_; + static constexpr 
thread::ScaleType::Kind Scale = Scale_; + static constexpr FloatRoundStyle Round = Round_; +}; + +template < + template class ActivationFunctor_, + thread::ScaleType::Kind Scale_ = thread::ScaleType::Default, + FloatRoundStyle Round_ = FloatRoundStyle::round_to_nearest +> +struct [[deprecated("Use TmaWarpSpecializedCooperative with fusion::LinCombEltAct instead")]] +TmaWarpSpecializedCooperativeElementwise : public TmaWarpSpecializedCooperativeElementwiseBase { + template + using ActivationFunctor = ActivationFunctor_; + static constexpr thread::ScaleType::Kind Scale = Scale_; + static constexpr FloatRoundStyle Round = Round_; +}; + +struct TmaWarpSpecializedBiasElementwiseBase : public TmaWarpSpecialized{}; +struct TmaWarpSpecializedCooperativeBiasElementwiseBase : public TmaWarpSpecializedCooperative {}; + +template < + template class ActivationFunctor_, + class ElementT_, + template class BiasOp_, + bool StoreT_, + class ElementBias_ +> +struct [[deprecated("Use TmaWarpSpecialized with fusion::LinCombPerRowBiasEltActAux instead")]] +TmaWarpSpecializedBiasElementwise : public TmaWarpSpecializedBiasElementwiseBase { + template + using ActivationFunctor = ActivationFunctor_; + using ElementT = ElementT_; + + template + using BiasOp = BiasOp_; + + static constexpr bool StoreT = StoreT_; + using ElementBias = ElementBias_; +}; + +template < + template class ActivationFunctor_, + class ElementT_, + template class BiasOp_, + bool StoreT_, + class ElementBias_ +> +struct [[deprecated("Use TmaWarpSpecializedCooperative with fusion::LinCombPerRowBiasEltActAux instead")]] +TmaWarpSpecializedCooperativeBiasElementwise : public TmaWarpSpecializedCooperativeBiasElementwiseBase { + template + using ActivationFunctor = ActivationFunctor_; + + using ElementT = ElementT_; + + template + using BiasOp = BiasOp_; + + static constexpr bool StoreT = StoreT_; + using ElementBias = ElementBias_; +}; + +////////////////////////////////////////////////////////////////////////////// +// +// Collective Dispatch Policies +// +////////////////////////////////////////////////////////////////////////////// + +template< + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_ +> +struct Sm90TmaWarpSpecialized { + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static int FragmentSize = FragmentSize_; + constexpr static bool ReuseSmemC = ReuseSmemC_; + constexpr static bool DelayTmaStore = DelayTmaStore_; +}; + +template< + int StagesC_, + int StagesD_, + int FragmentSize_, + bool ReuseSmemC_, + bool DelayTmaStore_, + int NumEpilogueWarpGroups_ +> +struct Sm90PtrArrayTmaWarpSpecialized { + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static int FragmentSize = FragmentSize_; + constexpr static bool ReuseSmemC = ReuseSmemC_; + constexpr static bool DelayTmaStore = DelayTmaStore_; + constexpr static int NumEpilogueWarpGroups = NumEpilogueWarpGroups_; +}; + +// DEPRECATED policies, will be removed in next release +template< + int StagesC_, + int StagesD_, + int FragmentSize_ = 2 +> +struct Sm90TmaWarpSpecializedBiasElementwise { + constexpr static int StagesC = StagesC_; + constexpr static int StagesD = StagesD_; + constexpr static int FragmentSize = FragmentSize_; +}; + +////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue diff --git a/include/cutlass/epilogue/fusion/callbacks.hpp b/include/cutlass/epilogue/fusion/callbacks.hpp 
new file mode 100644 index 0000000000..9ee37234cb --- /dev/null +++ b/include/cutlass/epilogue/fusion/callbacks.hpp @@ -0,0 +1,89 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +#pragma once + +#include "cutlass/detail/dependent_false.hpp" +#include "cutlass/epilogue/fusion/operations.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Dispatch interface for epilogue fusion callbacks +// For visitor fusions, this is just a convenience wrapper to provide metadata and non-nested args. +// It is also valid to just pass visitor callbacks directly to the collective, e.g. fusion::Sm90LinearCombination, +// provided the collective supports a visitor callbacks interface. This is useful for implementing custom fusions. +template < + class DispatchPolicy, // specialize on collective's dispatch policy since callbacks API will depend on collective's algorithm + class Operation, // the fusion operation being performed, e.g. fusion::LinearCombination + class CtaTile_MNK, // computed tile per CTA + class EpilogueTile_MN, // epilogue subtile size + class... Args // callbacks implementation dependent args (e.g. 
copy atoms, smem layouts) +> +struct FusionCallbacks { + static_assert(cutlass::detail::dependent_false, "Could not find a callbacks specialization."); +}; + +// Metadata helper to handle custom EVTs or other non-FusionCallbacks types +template +struct FusionCallbacksTraits { + using DispatchPolicy = void; + using Operation = T; + using CtaTile_MNK = void; + using EpilogueTile_MN = void; + using ElementCompute = void; +}; + +template < + class DispatchPolicy_, + class Operation_, + class CtaTile_MNK_, + class EpilogueTile_MN_, + class... Args +> +struct FusionCallbacksTraits< + FusionCallbacks +> { + using DispatchPolicy = DispatchPolicy_; + using Operation = Operation_; + using CtaTile_MNK = CtaTile_MNK_; + using EpilogueTile_MN = EpilogueTile_MN_; + using ElementCompute = typename Operation::ElementCompute; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/fusion/operations.hpp b/include/cutlass/epilogue/fusion/operations.hpp new file mode 100644 index 0000000000..1ef06a538b --- /dev/null +++ b/include/cutlass/epilogue/fusion/operations.hpp @@ -0,0 +1,495 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. 
+ * + **************************************************************************************************/ + +#pragma once + +#include +#include +#include +#include // cute::false_type + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Fusion Operations +// Template args must not be implementation dependent +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +struct FusionOperation { + // metadata types/queries that can be overrided + using ElementOutput = void; + using ElementCompute = void; + + using ElementSource = void; + static constexpr bool IsSourceSupported = false; + + using ElementScalar = void; + static constexpr int AlignmentScalar = 0; + static constexpr bool IsScaleFactorSupported = false; + static constexpr bool IsPerRowScaleSupported = false; + static constexpr bool IsPerColScaleSupported = false; + + using ElementBias = void; + static constexpr int AlignmentBias = 0; + static constexpr bool IsPerRowBiasSupported = false; + static constexpr bool IsPerColBiasSupported = false; + static constexpr bool IsDePerRowBiasSupported = false; + + using ActivationFn = void; + static constexpr bool IsEltActSupported = false; + static constexpr bool IsDeEltActSupported = false; + + using ElementAux = void; + using GmemLayoutTagAux = void; + static constexpr int AlignmentAux = 0; + static constexpr bool IsAuxOutSupported = false; + static constexpr bool IsAuxInSupported = false; + + using ElementAmax = void; + static constexpr bool IsAbsMaxSupported = false; + +}; + +// D = alpha * acc +template< + class ElementOutput_, + class ElementCompute_, + class ElementScalar_ = ElementCompute_, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct ScaledAcc : FusionOperation { + using ElementOutput = ElementOutput_; + using ElementCompute = ElementCompute_; + using ElementScalar = ElementScalar_; + static constexpr int AlignmentScalar = 1; + static constexpr auto RoundStyle = RoundStyle_; +}; + +// D = alpha * acc + beta * C +template< + class ElementOutput_, + class ElementCompute_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinearCombination + : ScaledAcc { + using ElementSource = ElementSource_; + static constexpr bool IsSourceSupported = true; +}; + +// D = activation(alpha * acc + beta * C) +template< + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombEltAct + : LinearCombination { + using ActivationFn = ActivationFn_; + static constexpr bool IsEltActSupported = true; +}; + +// D = softmax(top_k(alpha * acc + beta * C)) +template< + int TopK, + class ElementOutput_, + class ElementCompute_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombTopKSoftmaxCol + : LinearCombination { +}; + + +// D = alpha * acc + beta * C + per-row bias +template< + class ElementOutput_, + class ElementCompute_, + class 
ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombPerRowBias + : LinearCombination { + using ElementBias = ElementBias_; + static constexpr int AlignmentBias = AlignmentBias_; + static constexpr bool IsPerRowBiasSupported = true; +}; + +// D = alpha * acc + beta * C + per-column bias +template< + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombPerColBias + : LinearCombination { + using ElementBias = ElementBias_; + static constexpr int AlignmentBias = AlignmentBias_; + static constexpr bool IsPerColBiasSupported = true; +}; + +// D = activation(alpha * acc + beta * C + per-row bias) +template< + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombPerRowBiasEltAct + : LinCombPerRowBias { + using ActivationFn = ActivationFn_; + static constexpr bool IsEltActSupported = true; +}; + +// D = activation(alpha * acc + beta * C + per-column bias) +template< + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombPerColBiasEltAct + : LinCombPerColBias { + using ActivationFn = ActivationFn_; + static constexpr bool IsEltActSupported = true; +}; + +// D = activation(alpha * acc + beta * C + per-row bias) +// aux = alpha * acc + beta * C + per-row bias +template< + class GmemLayoutTagAux_, + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementAux_ = ElementOutput_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentAux_ = 128 / cute::sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombPerRowBiasEltActAux + : LinCombPerRowBiasEltAct { + using ElementAux = ElementAux_; + using GmemLayoutTagAux = GmemLayoutTagAux_; + static constexpr int AlignmentAux = AlignmentAux_; + static constexpr bool IsAuxOutSupported = true; +}; + +// D = activation(alpha * acc + beta * C + per-col bias) +// aux = alpha * acc + beta * C + per-col bias +template< + class GmemLayoutTagAux_, + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementAux_ = ElementOutput_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentAux_ = 128 / cute::sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombPerColBiasEltActAux + : LinCombPerColBiasEltAct { + using ElementAux = ElementAux_; + using GmemLayoutTagAux 
= GmemLayoutTagAux_; + static constexpr int AlignmentAux = AlignmentAux_; + static constexpr bool IsAuxOutSupported = true; +}; + +// D = activation(per-row alpha * acc + per-row beta * C + per-row bias) +template< + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, // per-row alpha/beta + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + int AlignmentScalar_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct PerRowLinCombPerRowBiasEltAct + : LinCombPerRowBiasEltAct { + static constexpr int AlignmentScalar = AlignmentScalar_; + static constexpr bool IsPerRowScaleSupported = true; +}; + +// D = per-column alpha * per-row alpha * acc + beta * C +template< + class ElementOutput_, + class ElementCompute_, + class ElementSource_ = ElementCompute_, + class ElementScalar_ = ElementCompute_, + int AlignmentScalar_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct OuterProdLinComb : FusionOperation { + using ElementOutput = ElementOutput_; + using ElementCompute = ElementCompute_; + using ElementSource = ElementSource_; + using ElementScalar = ElementScalar_; + static constexpr int AlignmentScalar = AlignmentScalar_; + static constexpr auto RoundStyle = RoundStyle_; + static constexpr bool IsSourceSupported = true; + static constexpr bool IsPerRowScaleSupported = true; + static constexpr bool IsPerColScaleSupported = true; +}; + +// D = activation(per-col alpha * acc + per-col beta * C + per-column bias) +template< + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, // per-row alpha/beta + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + int AlignmentScalar_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct PerColLinCombPerColBiasEltAct + : LinCombPerColBiasEltAct { + static constexpr int AlignmentScalar = AlignmentScalar_; + static constexpr bool IsPerColScaleSupported = true; +}; + +// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias +// if D is fp8 +// D = scale_d * activation(Z) +// else +// D = activation(Z) +template< + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct ScaledLinCombPerRowBiasEltAct + : LinCombPerRowBiasEltAct { + static constexpr bool IsScaleFactorSupported = true; +}; + +// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-col bias +// if D is fp8 +// D = scale_d * activation(Z) +// else +// D = activation(Z) +template< + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct ScaledLinCombPerColBiasEltAct + : LinCombPerColBiasEltAct { + static constexpr bool IsScaleFactorSupported = true; +}; + +// Z = scale_a * scale_b * alpha * acc + 
scale_c * beta * C + per-row bias +// if D is fp8 +// amax_d = max(abs(elements in activation(Z))) +// D = scale_d * activation(Z) +// else +// D = activation(Z) +// if Aux is fp8 +// amax_aux = max(abs(elements in Z)) +// Aux = scale_aux * Z +// else +// Aux = Z +template< + class GmemLayoutTagAux_, + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementAux_ = ElementOutput_, + class ElementAmax_ = ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentAux_ = 128 / cute::sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct ScaledLinCombPerRowBiasEltActAmaxAux + : ScaledLinCombPerRowBiasEltAct { + using ElementAmax = ElementAmax_; + static constexpr bool IsAbsMaxSupported = true; + + using ElementAux = ElementAux_; + using GmemLayoutTagAux = GmemLayoutTagAux_; + static constexpr int AlignmentAux = AlignmentAux_; + static constexpr bool IsAuxOutSupported = true; +}; + +// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-col bias +// if D is fp8 +// amax_d = max(abs(elements in activation(Z))) +// D = scale_d * activation(Z) +// else +// D = activation(Z) +// if Aux is fp8 +// amax_aux = max(abs(elements in Z)) +// Aux = scale_aux * Z +// else +// Aux = Z +template< + class GmemLayoutTagAux_, + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementAux_ = ElementOutput_, + class ElementAmax_ = ElementCompute_, + class ElementBias_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentAux_ = 128 / cute::sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct ScaledLinCombPerColBiasEltActAmaxAux + : ScaledLinCombPerColBiasEltAct { + using ElementAmax = ElementAmax_; + static constexpr bool IsAbsMaxSupported = true; + + using ElementAux = ElementAux_; + using GmemLayoutTagAux = GmemLayoutTagAux_; + static constexpr int AlignmentAux = AlignmentAux_; + static constexpr bool IsAuxOutSupported = true; +}; + +// Z = Aux +// dY = alpha * acc + beta * C +// D = d_activation(dY, Z) +template< + class GmemLayoutTagAux_, + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementAux_ = ElementOutput_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentAux_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = FloatRoundStyle::round_to_nearest +> +struct LinCombDeEltAct + : LinearCombination { + using ActivationFn = ActivationFn_; + static constexpr bool IsDeEltActSupported = true; + + using ElementAux = ElementAux_; + using GmemLayoutTagAux = GmemLayoutTagAux_; + static constexpr int AlignmentAux = AlignmentAux_; + static constexpr bool IsAuxInSupported = true; +}; + +// Z = Aux +// dY = alpha * acc + beta * C +// D = d_activation(dY, Z) +// dBias = sum of columns of D +template< + class GmemLayoutTagAux_, + template class ActivationFn_, + class ElementOutput_, + class ElementCompute_, + class ElementAux_ = ElementOutput_, + class ElementBias_ = ElementCompute_, + class ElementSource_ = ElementOutput_, + class ElementScalar_ = ElementCompute_, + int AlignmentAux_ = 128 / cute::sizeof_bits_v, + int AlignmentBias_ = 128 / cute::sizeof_bits_v, + FloatRoundStyle RoundStyle_ = 
FloatRoundStyle::round_to_nearest +> +struct LinCombDeEltActDePerRowBias + : LinCombDeEltAct { + using ElementBias = ElementBias_; + static constexpr int AlignmentBias = AlignmentBias_; + static constexpr bool IsDePerRowBiasSupported = true; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp new file mode 100644 index 0000000000..3e57fa0ba6 --- /dev/null +++ b/include/cutlass/epilogue/fusion/sm90_callbacks_tma_warpspecialized.hpp @@ -0,0 +1,2688 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Fusion callbacks specializations for the sm90 TMA warp-specialized (ws) epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" + +#include "cute/tensor.hpp" + +#include "cutlass/epilogue/dispatch_policy.hpp" +#include "cutlass/epilogue/fusion/callbacks.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp" + +#include "cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +using Sm90EVT = Sm90TreeVisitor; + +// D = alpha * acc +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::ScaledAcc, + CtaTileShapeMNK, + EpilogueTile +> : Sm90EVT, + Sm90ScalarBroadcast>, + Sm90AccFetch + > { + using Impl = + Sm90EVT, + Sm90ScalarBroadcast>, + Sm90AccFetch + >; + using Operation = fusion::ScaledAcc; + + struct Arguments { + // Give a name and flat ordering to the fusion callback args + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + + // Conversion to the args expected by the visitor implementation + // to_underlying_arguments will implicitly call this + operator typename Impl::Arguments() const { + return + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }; // end binary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = alpha * acc + beta * C +template< + class ElementOutput, + class ElementCompute, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinearCombination = + Sm90EVT, // beta * C + (alpha * acc) + Sm90ScalarBroadcast>, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + Sm90ScalarBroadcast>, // alpha + Sm90AccFetch // acc + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinearCombination, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinearCombination::type, ElementCompute, ElementSource, ElementScalar, RoundStyle> { + + using Impl = Sm90LinearCombination::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinearCombination; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = 
ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + operator typename Impl::Arguments() const { + return + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = alpha * acc + beta * C, where beta and alpha can be vectors for each batch +template< + class ElementOutput, + class ElementCompute, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinearCombinationPtrArray = + Sm90EVT, // beta * C + (alpha * acc) + Sm90ScalarBroadcastPtrArray>, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + Sm90ScalarBroadcastPtrArray>, // alpha + Sm90AccFetch // acc + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups, + class ElementOutput, + class ElementCompute, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90PtrArrayTmaWarpSpecialized, + fusion::LinearCombination, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinearCombinationPtrArray::type, ElementCompute, ElementSource, ElementScalar, RoundStyle> { + + using Impl = Sm90LinearCombinationPtrArray::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinearCombination; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementScalar const* const* alpha_ptr_array = nullptr; + ElementScalar const* const* beta_ptr_array = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + operator typename Impl::Arguments() const { + return + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = activation(alpha * acc + beta * C) +template< + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombEltAct = + Sm90EVT, // activation(beta * C + (alpha * acc)) + Sm90LinearCombination // beta * C + (alpha * acc) + >; 
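+// Illustrative sketch (not part of the CUTLASS API): every FusionCallbacks specialization in this
+// file exposes a flat, user-facing Arguments struct whose implicit conversion operator repacks the
+// named fields into the nested argument tree consumed by the epilogue visitor tree (EVT), which is
+// what to_underlying_arguments() relies on. The self-contained example below uses hypothetical
+// types (NestedTreeArgs, FlatArguments) purely to spell out that flat-to-nested conversion pattern;
+// it is not the library's actual argument plumbing.
+//
+//   struct NestedTreeArgs {                                  // what the visitor tree consumes
+//     struct { float scalar; float const* ptr; } beta;       // leaf args : beta
+//     struct {} source;                                      // leaf args : C
+//     struct {
+//       struct { float scalar; float const* ptr; } alpha;    // leaf args : alpha
+//       struct {} acc;                                       // leaf args : acc
+//     } alpha_times_acc;                                     // binary op : alpha * acc
+//   };
+//
+//   struct FlatArguments {                                   // what the user fills in
+//     float alpha = 1.0f;
+//     float beta  = 0.0f;
+//     float const* alpha_ptr = nullptr;
+//     float const* beta_ptr  = nullptr;
+//
+//     operator NestedTreeArgs() const {                      // implicit repacking, analogous to
+//       return { {beta, beta_ptr}, {},                       // the conversion operators above
+//                { {alpha, alpha_ptr}, {} } };
+//     }
+//   };
+//
+//   // Usage: FlatArguments args{/*alpha=*/2.0f, /*beta=*/1.0f};  // D = 2 * acc + 1 * C
+//   //        NestedTreeArgs tree = args;                         // conversion happens implicitly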
+ +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombEltAct, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinCombEltAct { + + using Impl = Sm90LinCombEltAct::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinCombEltAct; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // unary op: activation(beta * C + (alpha * acc)) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args: activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = activation(alpha * acc + beta * C), where beta and alpha can be vectors for each batch +template< + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombEltActPtrArray = + Sm90EVT, // activation(beta * C + (alpha * acc)) + Sm90LinearCombinationPtrArray // beta * C + (alpha * acc) + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + int NumEpilogueWarpGroups, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90PtrArrayTmaWarpSpecialized, + fusion::LinCombEltAct, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinCombEltActPtrArray { + + using Impl = Sm90LinCombEltActPtrArray::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinCombEltAct; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + ElementScalar const* const* alpha_ptr_array = nullptr; + ElementScalar const* const* beta_ptr_array = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // 
unary op: activation(beta * C + (alpha * acc)) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {beta_ptr_array}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {alpha_ptr_array}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args: activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = alpha * acc + beta * C + per-row bias +template< + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerRowBias = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ScalarBroadcast>, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast>, // alpha + Sm90AccFetch, // acc + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_1,_0,int64_t>, AlignmentBias> // bias + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombPerRowBias, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinCombPerRowBias< + CtaTileShapeMNK, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle> { + using Impl = Sm90LinCombPerRowBias< + CtaTileShapeMNK, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>; + using Operation = fusion::LinCombPerRowBias< + ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + operator typename Impl::Arguments() const { + return + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = alpha * acc + beta * C + per-column bias +template< + int StagesC, + class CtaTileShapeMNK, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + class 
ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerColBias = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ScalarBroadcast>, // beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast>, // alpha + Sm90AccFetch, // acc + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_0,_1,int64_t>, AlignmentBias> // bias + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombPerColBias, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinCombPerColBias< + StagesC, CtaTileShapeMNK, EpilogueTile, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle> { + using Impl = Sm90LinCombPerColBias< + StagesC, CtaTileShapeMNK, EpilogueTile, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>; + using Operation = fusion::LinCombPerColBias< + ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle>; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + operator typename Impl::Arguments() const { + return + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }; // end ternary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = activation(alpha * acc + beta * C + per-row bias) +template< + class CtaTileShapeMNK, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerRowBiasEltAct = + Sm90EVT, + Sm90LinCombPerRowBias + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombPerRowBiasEltAct< + 
ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90LinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + using Operation = + fusion::LinCombPerRowBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // unary op : activation(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = activation(alpha * acc + beta * C + per-column bias) +template< + int StagesC, + class CtaTileShapeMNK, + class EpilogueTile, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerColBiasEltAct = + Sm90EVT, + Sm90LinCombPerColBias + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombPerColBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinCombPerColBiasEltAct< + StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90LinCombPerColBiasEltAct< + StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, 
RoundStyle + >; + using Operation = + fusion::LinCombPerColBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // unary op : activation(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = activation(alpha * acc + beta * C + per-row bias) +// Aux = alpha * acc + beta * C + per-row bias) +template< + class CtaTileShapeMNK, + class EpilogueTile, + int Stages, + class StrideAux, + class SmemLayoutAtom, + class CopyOpR2S, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerRowBiasEltActAux = + Sm90EVT, + Sm90EVT, + Sm90LinCombPerRowBias + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class GmemLayoutTagAux, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentAux, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpR2S +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombPerRowBiasEltActAux< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile, + SmemLayoutAtom, + CopyOpR2S +> : Sm90LinCombPerRowBiasEltActAux< + CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90LinCombPerRowBiasEltActAux< + CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, 
ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + using Operation = + fusion::LinCombPerRowBiasEltActAux< + GmemLayoutTagAux, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + using StrideAux = cutlass::gemm::TagToStrideC_t; + ElementAux* aux_ptr = nullptr; + StrideAux dAux = {}; + + operator typename Impl::Arguments() const { + return + { // unary op : activation(store(beta * C + (alpha * acc + bias))) + { // unary op : store(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + {aux_ptr, dAux} // unary args : store + }, // end unary op + activation // unary args : activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// D = activation(alpha * acc + beta * C + per_col bias) +// Aux = alpha * acc + beta * C + per_col bias) +template< + int StagesC, + class CtaTileShapeMNK, + class EpilogueTile, + int Stages, + class StrideAux, + class SmemLayoutAtom, + class CopyOpR2S, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombPerColBiasEltActAux = + Sm90EVT, + Sm90EVT, + Sm90LinCombPerColBias + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class GmemLayoutTagAux, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentAux, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpR2S +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombPerColBiasEltActAux< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile, + SmemLayoutAtom, + CopyOpR2S +> : Sm90LinCombPerColBiasEltActAux< + StagesC, CtaTileShapeMNK, EpilogueTile, StagesD, 
cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90LinCombPerColBiasEltActAux< + StagesC, CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + using Operation = + fusion::LinCombPerColBiasEltActAux< + GmemLayoutTagAux, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + using StrideAux = cutlass::gemm::TagToStrideC_t; + ElementAux* aux_ptr = nullptr; + StrideAux dAux = {}; + + operator typename Impl::Arguments() const { + return + { // unary op : activation(store(beta * C + (alpha * acc + bias))) + { // unary op : store(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + {aux_ptr, dAux} // unary args : store + }, // end unary op + activation // unary args : activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = per-row alpha * acc + per-row beta * C + per-row bias +template< + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + int AlignmentScalar = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90PerRowLinCombPerRowBias = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride, AlignmentScalar>, // beta, dynamic scalar/vector broadcast + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride, AlignmentScalar>, // alpha, dynamic scalar/vector broadcast + Sm90AccFetch, // acc + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_1,_0,int64_t>, AlignmentBias> // bias + > + >; + +// D = activation(per-row alpha * acc + per-row beta * C + per-row bias) +template< + class CtaTileShapeMNK, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class 
ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + int AlignmentScalar = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90PerRowLinCombPerRowBiasEltAct = + Sm90EVT, + Sm90PerRowLinCombPerRowBias + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + int AlignmentScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::PerRowLinCombPerRowBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm90PerRowLinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + > { + + using Impl = + Sm90PerRowLinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >; + using Operation = + fusion::PerRowLinCombPerRowBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >; + + struct Arguments { + using StrideAlpha = Stride; + using StrideBeta = Stride; + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + StrideAlpha dAlpha = {bool(1), _0{}, 0}; + StrideBeta dBeta = {bool(1), _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // unary op : activation(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {beta_ptr, beta, dBeta}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {alpha_ptr, alpha, dAlpha}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = per-col alpha * acc + per-col beta * C + per-column bias +template< + int StagesC, + class CtaTileShapeMNK, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + int AlignmentScalar = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90PerColLinCombPerColBias = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride<_0,bool,int64_t>, AlignmentScalar>, // beta, 
dynamic scalar/vector broadcast + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementScalar, ElementCompute, Stride<_0,bool,int64_t>, AlignmentScalar>, // alpha, dynamic scalar/vector broadcast + Sm90AccFetch, // acc + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_0,_1,int64_t>, AlignmentBias> // bias + > + >; + +// D = activation(per-col alpha * acc + per-col beta * C + per-column bias) +template< + int StagesC, + class CtaTileShapeMNK, + class EpilogueTile, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + int AlignmentScalar = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90PerColLinCombPerColBiasEltAct = + Sm90EVT, + Sm90PerColLinCombPerColBias + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + int AlignmentScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::PerColLinCombPerColBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm90PerColLinCombPerColBiasEltAct< + StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + > { + + using Impl = + Sm90PerColLinCombPerColBiasEltAct< + StagesC, CtaTileShapeMNK, EpilogueTile, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >; + using Operation = + fusion::PerColLinCombPerColBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, AlignmentScalar, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,bool,int64_t>; + using StrideBeta = Stride<_0,bool,int64_t>; + StrideAlpha dAlpha = {_0{}, bool(1), 0}; + StrideBeta dBeta = {_0{}, bool(1), 0}; + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // unary op : activation(beta * C + (alpha * acc + bias)) + { // ternary op : beta * C + (alpha * acc + bias) + {beta_ptr, beta, dBeta}, // leaf args : beta + {}, // leaf args : C + { // ternary op : alpha * acc + bias + {alpha_ptr, alpha, dAlpha}, // leaf args : alpha + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + 
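+// Illustrative sketch (not the CUTLASS implementation): the per-row and per-column fusions above
+// differ only in the mode along which the bias and scale vectors are broadcast. A per-row bias is
+// a column vector with stride (_1,_0,int64_t) -- one value per row, reused across columns, loaded
+// via Sm90ColBroadcast -- while a per-column bias is a row vector with stride (_0,_1,int64_t),
+// loaded via Sm90RowBroadcast. A stride whose leading mode is a runtime bool (as used for
+// alpha/beta above) lets the same node broadcast either a single scalar or a per-row/per-column
+// vector, selected at runtime. The plain C++ reference below, with hypothetical names, spells out
+// the broadcast arithmetic for a small M x N tile.
+//
+//   #include <array>
+//
+//   constexpr int M = 2, N = 3;
+//
+//   // D(i,j) = alpha * acc(i,j) + beta * C(i,j) + bias_row(i)   (per-row bias,    stride (1,0))
+//   // D(i,j) = alpha * acc(i,j) + beta * C(i,j) + bias_col(j)   (per-column bias, stride (0,1))
+//   inline void lin_comb_bias(float alpha, float beta,
+//                             std::array<float, M * N> const& acc,
+//                             std::array<float, M * N> const& C,
+//                             std::array<float, M> const& bias_row,
+//                             std::array<float, N> const& bias_col,
+//                             bool per_row,
+//                             std::array<float, M * N>& D) {
+//     for (int i = 0; i < M; ++i) {
+//       for (int j = 0; j < N; ++j) {
+//         float bias = per_row ? bias_row[i] : bias_col[j];
+//         D[i * N + j] = alpha * acc[i * N + j] + beta * C[i * N + j] + bias;
+//       }
+//     }
+//   }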
+///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +constexpr bool is_fp8_v = cute::is_same_v || cute::is_same_v; + +// We only apply the scaling factor if output is fp8 +template +struct ScaleOutOp { template using Op = cutlass::first; }; +template <> +struct ScaleOutOp { template using Op = cutlass::multiplies; }; +template <> +struct ScaleOutOp { template using Op = cutlass::multiplies; }; + +template +using amax = cutlass::maximum_absolute_value_reduction; // propogate nans + +}; // end namespace detail + +// D = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias +template< + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerRowBias = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ScalarBroadcast, 2>, // scale_c * beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast, 3>, // scale_a * scale_b * alpha + Sm90AccFetch, // acc + Sm90ColBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_1,_0,int64_t>, AlignmentBias> // bias + > + >; + +// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias +// if D is fp8 +// D = scale_d * activation(Z) +// else +// D = activation(Z) +template< + class CtaTileShapeMNK, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerRowBiasEltAct = + Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d + Sm90EVT, // activation(Z) + // Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias + Sm90ScaledLinCombPerRowBias + >, + Sm90ScalarBroadcast // scale_d + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::ScaledLinCombPerRowBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm90ScaledLinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90ScaledLinCombPerRowBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + using Operation = + fusion::ScaledLinCombPerRowBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + ElementScalar 
scale_a = ElementScalar(1); + ElementScalar scale_b = ElementScalar(1); + ElementScalar scale_c = ElementScalar(1); + ElementScalar scale_d = ElementScalar(1); + ElementScalar const* scale_a_ptr = nullptr; + ElementScalar const* scale_b_ptr = nullptr; + ElementScalar const* scale_c_ptr = nullptr; + ElementScalar const* scale_d_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // binary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) * scale_d + { // unary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) + { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) + {{beta, scale_c}, + {beta_ptr, scale_c_ptr}, + {dBeta, {_0{}, _0{}, 0}} + }, // leaf args : (scale_c * beta) + {}, // leaf args : C + { // ternary op : (scale_a * scale_b * alpha) * acc + bias + {{alpha, scale_a, scale_b}, + {alpha_ptr, scale_a_ptr, scale_b_ptr}, + {dAlpha, {_0{}, _0{}, 0}, {_0{}, _0{}, 0}} + }, // leaf args : (scale_a * scale_b * alpha) + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }, // end unary op + {{scale_d}, + {scale_d_ptr} + }, // leaf args : scale_d + {} // binary args : multiplies or first + }; // end binary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-col bias +template< + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerColBias = + Sm90EVT, // beta * C + (alpha * acc + bias) + Sm90ScalarBroadcast, 2>, // scale_c * beta + Sm90SrcFetch, // C + Sm90EVT, // alpha * acc + bias + Sm90ScalarBroadcast, 3>, // scale_a * scale_b * alpha + Sm90AccFetch, // acc + Sm90RowBroadcast<0, CtaTileShapeMNK, ElementBias, ElementCompute, Stride<_0,_1,int64_t>, AlignmentBias> // bias + > + >; + +// Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-col bias +// if D is fp8 +// D = scale_d * activation(Z) +// else +// D = activation(Z) +template< + class CtaTileShapeMNK, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerColBiasEltAct = + Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d + Sm90EVT, // activation(Z) + // Z = scale_a * scale_b * alpha * acc + beta * scale_c * C + per-row bias + Sm90ScaledLinCombPerColBias + >, + 
Sm90ScalarBroadcast // scale_d + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::ScaledLinCombPerColBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile +> : Sm90ScaledLinCombPerColBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90ScaledLinCombPerColBiasEltAct< + CtaTileShapeMNK, ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + using Operation = + fusion::ScaledLinCombPerColBiasEltAct< + ActivationFn, ElementOutput, ElementCompute, ElementBias, ElementSource, ElementScalar, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + ElementScalar scale_a = ElementScalar(1); + ElementScalar scale_b = ElementScalar(1); + ElementScalar scale_c = ElementScalar(1); + ElementScalar scale_d = ElementScalar(1); + ElementScalar const* scale_a_ptr = nullptr; + ElementScalar const* scale_b_ptr = nullptr; + ElementScalar const* scale_c_ptr = nullptr; + ElementScalar const* scale_d_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + operator typename Impl::Arguments() const { + return + { // binary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) * scale_d + { // unary op : activation((scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias)) + { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) + {{beta, scale_c}, + {beta_ptr, scale_c_ptr}, + {dBeta, {_0{}, _0{}, 0}} + }, // leaf args : (scale_c * beta) + {}, // leaf args : C + { // ternary op : (scale_a * scale_b * alpha) * acc + bias + {{alpha, scale_a, scale_b}, + {alpha_ptr, scale_a_ptr, scale_b_ptr}, + {dAlpha, {_0{}, _0{}, 0}, {_0{}, _0{}, 0}} + }, // leaf args : (scale_a * scale_b * alpha) + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + activation // unary args : activation + }, // end unary op + {{scale_d}, + {scale_d_ptr} + }, // leaf args : scale_d + {} // binary args : multiplies or first + }; // end binary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias +// if D is fp8 +// amax_d = max(abs(elements in activation(Z))) +// D = 
scale_d * activation(Z) +// else +// D = activation(Z) +// if Aux is fp8 +// amax_aux = max(abs(elements in Z)) +// Aux = scale_aux * Z +// else +// Aux = Z + +// fp8 aux specialization +template< + class CtaTileShapeMNK, + class EpilogueTile, + int StagesD, + class StrideAux, + class SmemLayoutAtom, + class CopyOpR2S, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementAmax = ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerRowBiasEltActAmaxAuxFp8 = + Sm90SplitTreeVisitor< + // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias + Sm90ScaledLinCombPerRowBias, + // D = activation(Z) * scale_d, amax_d = max(abs(elements in D)) + Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d + Sm90EVT, // amax_d + Sm90EVT, // activation(Z) + Sm90SplitTreeFetch // Z + > + >, + Sm90ScalarBroadcast // scale_d + >, + // Aux = Z * scale_aux, amax_aux = max(abs(elements in Aux)) + Sm90EVT, // store(Aux) + Sm90EVT, // Z * scale_aux + Sm90EVT, // amax_aux + Sm90SplitTreeFetch // Z + >, + Sm90ScalarBroadcast // scale_aux + > + > + >; + +// non-fp8 aux specialization +// lets us use some EVT specializations such as relu + uint1b_t aux +template< + class CtaTileShapeMNK, + class EpilogueTile, + int StagesD, + class StrideAux, + class SmemLayoutAtom, + class CopyOpR2S, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementAmax = ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerRowBiasEltActAmaxAuxNotFp8 = + // D = activation(Z) * scale_d, amax_d = max(abs(elements in D)) + Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d + Sm90EVT, // amax_d + Sm90EVT, // activation(Z) + Sm90EVT, // Aux = Z + // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias + Sm90ScaledLinCombPerRowBias + > + > + >, + Sm90ScalarBroadcast // scale_d + >; + +// dispatcher +template< + class CtaTileShapeMNK, + class EpilogueTile, + int StagesD, + class StrideAux, + class SmemLayoutAtom, + class CopyOpR2S, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementAmax = ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerRowBiasEltActAmaxAux = conditional_t, + Sm90ScaledLinCombPerRowBiasEltActAmaxAuxFp8< + CtaTileShapeMNK, EpilogueTile, StagesD, StrideAux, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar,AlignmentAux, AlignmentBias, RoundStyle + >, + Sm90ScaledLinCombPerRowBiasEltActAmaxAuxNotFp8< + CtaTileShapeMNK, EpilogueTile, StagesD, 
StrideAux, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + > +>; + + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class GmemLayoutTagAux, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux, + class ElementAmax, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentAux, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpR2S +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::ScaledLinCombPerRowBiasEltActAmaxAux< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile, + SmemLayoutAtom, + CopyOpR2S +> : Sm90ScaledLinCombPerRowBiasEltActAmaxAux< + CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, + SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90ScaledLinCombPerRowBiasEltActAmaxAux< + CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, + SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + using Operation = + fusion::ScaledLinCombPerRowBiasEltActAmaxAux< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + ElementScalar scale_a = ElementScalar(1); + ElementScalar scale_b = ElementScalar(1); + ElementScalar scale_c = ElementScalar(1); + ElementScalar scale_d = ElementScalar(1); + ElementScalar const* scale_a_ptr = nullptr; + ElementScalar const* scale_b_ptr = nullptr; + ElementScalar const* scale_c_ptr = nullptr; + ElementScalar const* scale_d_ptr = nullptr; + + ElementScalar scale_aux = ElementScalar(1); + ElementScalar const* scale_aux_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_1,_0,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + ElementAmax* amax_D_ptr = nullptr; + ElementAmax* amax_aux_ptr = nullptr; + + using StrideAux = cutlass::gemm::TagToStrideC_t; + ElementAux* aux_ptr = nullptr; + StrideAux dAux = {}; + + operator typename Impl::Arguments() const { + // Only compute amax_d if D is fp8 + ElementAmax* amax_D_ptr_ = nullptr; + if constexpr (detail::is_fp8_v) { + amax_D_ptr_ = amax_D_ptr; + } + + // Aux is fp8 -> DAG arguments + if constexpr (detail::is_fp8_v) { + typename Impl::Arguments args; + // always use structured binding to unpack DAG args since it may or may not be a tuple + auto& 
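// A minimal host-side sketch of how the Arguments struct above might be populated for an
// fp8 D with an fp8 Aux tensor. The scale and pointer names are placeholders for
// illustration only and are not part of this interface:
//
//   Arguments args;
//   args.alpha        = 1.0f;                // combined as (scale_a * scale_b * alpha) * acc
//   args.beta         = 1.0f;                // combined as (scale_c * beta) * C
//   args.scale_a      = a_dequant_scale;     // per-tensor dequantization scales
//   args.scale_b      = b_dequant_scale;
//   args.scale_c      = c_dequant_scale;
//   args.scale_d      = d_quant_scale;       // applied (and amax_D written) only when D is fp8
//   args.scale_aux    = aux_quant_scale;     // applied (and amax_aux written) only when Aux is fp8
//   args.bias_ptr     = bias_device_ptr;     // per-row bias vector
//   args.aux_ptr      = aux_device_ptr;      // auxiliary output, same extent as D
//   args.amax_D_ptr   = amax_d_device_ptr;   // receives max(abs(activation(Z)))
//   args.amax_aux_ptr = amax_aux_device_ptr; // receives max(abs(Z))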
[Z_args, aux_args, D_args] = args; + + Z_args = + { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) + {{beta, scale_c}, + {beta_ptr, scale_c_ptr}, + {dBeta, {_0{}, _0{}, 0}} + }, // leaf args : (scale_c * beta) + {}, // leaf args : C + { // ternary op : (scale_a * scale_b * alpha) * acc + bias + {{alpha, scale_a, scale_b}, + {alpha_ptr, scale_a_ptr, scale_b_ptr}, + {dAlpha ,{_0{}, _0{}, 0}, {_0{}, _0{}, 0}} + }, // leaf args : (scale_a * scale_b * alpha) + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }; // end ternary op + + D_args = + { // binary op : activation(Z) * scale_d or activation(Z) + { // unary op : reduce(activation(Z)) + { // unary op : activation(Z) + {}, // leaf args : Z + activation // unary args : activation + }, // end unary op + {amax_D_ptr_} // unary args : reduce + }, // end unary op + {{scale_d}, + {scale_d_ptr} + }, // leaf args : scale_d + {} // binary args : multiplies or first + }; // end binary op + + aux_args = + { // unary op : store(Aux) + { // binary op : Z * scale_d or Z + { // unary op : reduce(Z) + {}, // leaf args : Z + {amax_aux_ptr} // unary args : reduce + }, // end unary op + {{scale_aux}, + {scale_aux_ptr} + }, // leaf args : scale_d + {} // binary args : multiplies + }, // end binary op + {aux_ptr, dAux} // unary args : store + }; // end unary op + + return args; + } + + // Aux is not fp8 -> Tree arguments + else { + return + { // binary op : activation(Z) * scale_d or activation(Z) + { // unary op : reduce(activation(Z)) + { // unary op : activation(Z) + { // unary op : store(Z) + { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) + {{beta, scale_c}, + {beta_ptr, scale_c_ptr}, + {dBeta, {_0{}, _0{}, 0}} + }, // leaf args : (scale_c * beta) + {}, // leaf args : C + { // ternary op : (scale_a * scale_b * alpha) * acc + bias + {{alpha, scale_a, scale_b}, + {alpha_ptr, scale_a_ptr, scale_b_ptr}, + {dAlpha, {_0{}, _0{}, 0}} + }, // leaf args : (scale_a * scale_b * alpha) + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias + }, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + {aux_ptr, dAux} // unary args : store + }, // end unary op + activation // unary args : activation + }, // end unary op + {amax_D_ptr_} // unary args : reduce + }, // end unary op + {{scale_d},{scale_d_ptr}}, // leaf args : scale_d + {} // binary args : multiplies or first + }; // end binary op + } + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-col bias +// if D is fp8 +// amax_d = max(abs(elements in activation(Z))) +// D = scale_d * activation(Z) +// else +// D = activation(Z) +// if Aux is fp8 +// amax_aux = max(abs(elements in Z)) +// Aux = scale_aux * Z +// else +// Aux = Z + +// fp8 aux specialization +template< + class CtaTileShapeMNK, + class EpilogueTile, + int StagesD, + class StrideAux, + class SmemLayoutAtom, + class CopyOpR2S, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementAmax = ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int 
AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerColBiasEltActAmaxAuxFp8 = + Sm90SplitTreeVisitor< + // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-col bias + Sm90ScaledLinCombPerColBias, + // D = activation(Z) * scale_d, amax_d = max(abs(elements in D)) + Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d + Sm90EVT, // amax_d + Sm90EVT, // activation(Z) + Sm90SplitTreeFetch // Z + > + >, + Sm90ScalarBroadcast // scale_d + >, + // Aux = Z * scale_aux, amax_aux = max(abs(elements in Aux)) + Sm90EVT, // store(Aux) + Sm90EVT, // Z * scale_aux + Sm90EVT, // amax_aux + Sm90SplitTreeFetch // Z + >, + Sm90ScalarBroadcast // scale_aux + > + > + >; + +// non-fp8 aux specialization +// lets us use some EVT specializations such as relu + uint1b_t aux +template< + class CtaTileShapeMNK, + class EpilogueTile, + int StagesD, + class StrideAux, + class SmemLayoutAtom, + class CopyOpR2S, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementAmax = ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerColBiasEltActAmaxAuxNotFp8 = + // D = activation(Z) * scale_d, amax_d = max(abs(elements in D)) + Sm90EVT::template Op, ElementOutput, ElementCompute, RoundStyle>, // activation(Z) * scale_d + Sm90EVT, // amax_d + Sm90EVT, // activation(Z) + Sm90EVT, // Aux = Z + // Z = scale_a * scale_b * alpha * acc + scale_c * beta * C + per-row bias + Sm90ScaledLinCombPerColBias + > + > + >, + Sm90ScalarBroadcast // scale_d + >; + +// dispatcher +template< + class CtaTileShapeMNK, + class EpilogueTile, + int StagesD, + class StrideAux, + class SmemLayoutAtom, + class CopyOpR2S, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementAmax = ElementCompute, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90ScaledLinCombPerColBiasEltActAmaxAux = conditional_t, + Sm90ScaledLinCombPerColBiasEltActAmaxAuxFp8< + CtaTileShapeMNK, EpilogueTile, StagesD, StrideAux, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar,AlignmentAux, AlignmentBias, RoundStyle + >, + Sm90ScaledLinCombPerColBiasEltActAmaxAuxNotFp8< + CtaTileShapeMNK, EpilogueTile, StagesD, StrideAux, SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + > +>; + + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class GmemLayoutTagAux, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux, + class ElementAmax, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentAux, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class 
CtaTileShapeMNK, + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpR2S +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::ScaledLinCombPerColBiasEltActAmaxAux< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile, + SmemLayoutAtom, + CopyOpR2S +> : Sm90ScaledLinCombPerColBiasEltActAmaxAux< + CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, + SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90ScaledLinCombPerColBiasEltActAmaxAux< + CtaTileShapeMNK, EpilogueTile, StagesD, cutlass::gemm::TagToStrideC_t, + SmemLayoutAtom, CopyOpR2S, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + using Operation = + fusion::ScaledLinCombPerColBiasEltActAmaxAux< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementAmax, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + ElementScalar scale_a = ElementScalar(1); + ElementScalar scale_b = ElementScalar(1); + ElementScalar scale_c = ElementScalar(1); + ElementScalar scale_d = ElementScalar(1); + ElementScalar const* scale_a_ptr = nullptr; + ElementScalar const* scale_b_ptr = nullptr; + ElementScalar const* scale_c_ptr = nullptr; + ElementScalar const* scale_d_ptr = nullptr; + + ElementScalar scale_aux = ElementScalar(1); + ElementScalar const* scale_aux_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using StrideBias = Stride<_0,_1,int64_t>; + ElementBias const* bias_ptr = nullptr; + StrideBias dBias = {}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + ElementAmax* amax_D_ptr = nullptr; + ElementAmax* amax_aux_ptr = nullptr; + + using StrideAux = cutlass::gemm::TagToStrideC_t; + ElementAux* aux_ptr = nullptr; + StrideAux dAux = {}; + + operator typename Impl::Arguments() const { + // Only compute amax_d if D is fp8 + ElementAmax* amax_D_ptr_ = nullptr; + if constexpr (detail::is_fp8_v) { + amax_D_ptr_ = amax_D_ptr; + } + + // Aux is fp8 -> DAG arguments + if constexpr (detail::is_fp8_v) { + typename Impl::Arguments args; + // always use structured binding to unpack DAG args since it may or may not be a tuple + auto& [Z_args, aux_args, D_args] = args; + + Z_args = + { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) + {{beta, scale_c}, + {beta_ptr, scale_c_ptr}, + {dBeta, {_0{}, _0{}, 0}} + }, // leaf args : (scale_c * beta) + {}, // leaf args : C + { // ternary op : (scale_a * scale_b * alpha) * acc + bias + {{alpha, scale_a, scale_b}, + {alpha_ptr, scale_a_ptr, scale_b_ptr}, + {dAlpha, {_0{}, _0{}, 0}, {_0{}, _0{}, 0}} + }, // leaf args : (scale_a * scale_b * alpha) + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias}, // leaf args : bias + {} // ternary args : 
multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }; // end ternary op + + D_args = + { // binary op : activation(Z) * scale_d or activation(Z) + { // unary op : reduce(activation(Z)) + { // unary op : activation(Z) + {}, // leaf args : Z + activation // unary args : activation + }, // end unary op + {amax_D_ptr_} // unary args : reduce + }, // end unary op + {{scale_d}, + {scale_d_ptr} + }, // leaf args : scale_d + {} // binary args : multiplies or first + }; // end binary op + + aux_args = + { // unary op : store(Aux) + { // binary op : Z * scale_d or Z + { // unary op : reduce(Z) + {}, // leaf args : Z + {amax_aux_ptr} // unary args : reduce + }, // end unary op + {{scale_aux}, + {scale_aux_ptr} + }, // leaf args : scale_d + {} // binary args : multiplies + }, // end binary op + {aux_ptr, dAux} // unary args : store + }; // end unary op + + return args; + } + + // Aux is not fp8 -> Tree arguments + else { + return + { // binary op : activation(Z) * scale_d or activation(Z) + { // unary op : reduce(activation(Z)) + { // unary op : activation(Z) + { // unary op : store(Z) + { // ternary op : (scale_c * beta) * C + ((scale_a * scale_b * alpha) * acc + bias) + {{beta, scale_c}, + {beta_ptr, scale_c_ptr}, + {dBeta, {_0{}, _0{}, 0}} + }, // leaf args : (scale_c * beta) + {}, // leaf args : C + { // ternary op : (scale_a * scale_b * alpha) * acc + bias + {{alpha, scale_a, scale_b}, + {alpha_ptr, scale_a_ptr, scale_b_ptr}, + {dAlpha, {_0{}, _0{}, 0}, {_0{}, _0{}, 0}} + }, // leaf args : (scale_a * scale_b * alpha) + {}, // leaf args : acc + {bias_ptr, ElementBias(0), dBias + }, // leaf args : bias + {} // ternary args : multiply_add + }, // end ternary op + {} // ternary args : multiply_add + }, // end ternary op + {aux_ptr, dAux} // unary args : store + }, // end unary op + activation // unary args : activation + }, // end unary op + {amax_D_ptr_} // unary args : reduce + }, // end unary op + {{scale_d},{scale_d_ptr}}, // leaf args : scale_d + {} // binary args : multiplies or first + }; // end binary op + } + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class CtaTileShapeMNK, + class EpilogueTile, + int Stages, + class StrideAux, + class SmemLayoutAtom, + class CopyOpS2R, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombDeEltAct = + Sm90EVT, // activation(beta * C + (alpha * acc), aux) + Sm90LinearCombination, // beta * C + (alpha * acc) + Sm90AuxLoad // aux + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class GmemLayoutTagAux, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux, + class ElementSource, + class ElementScalar, + int AlignmentAux, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpS2R +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombDeEltAct< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile, + SmemLayoutAtom, + CopyOpS2R +> : 
Sm90LinCombDeEltAct< + CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpS2R, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle + > { + + using Impl = + Sm90LinCombDeEltAct< + CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpS2R, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle + >; + using Operation = + fusion::LinCombDeEltAct< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementSource, ElementScalar, AlignmentAux, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + using StrideAux = cutlass::gemm::TagToStrideC_t; + ElementAux const* aux_ptr = nullptr; + StrideAux dAux = {}; + + operator typename Impl::Arguments() const { + return + { // binary op : activation(beta * C + (alpha * acc), aux) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + {aux_ptr, ElementAux(0), dAux}, // leaf args : aux + activation // binary args : activation + }; // end binary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class CtaTileShapeMNK, + class EpilogueTile, + int Stages, + class StrideAux, + class SmemLayoutAtom, + class CopyOpS2R, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux = ElementOutput, + class ElementBias = ElementOutput, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentAux = 128 / sizeof_bits_v, + int AlignmentBias = 128 / sizeof_bits_v, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombDeEltActDePerRowBias = + Sm90EVT, // Identity for final conversion + Sm90EVT, AlignmentBias>, + Sm90LinCombDeEltAct + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class GmemLayoutTagAux, + template class ActivationFn, + class ElementOutput, + class ElementCompute, + class ElementAux, + class ElementBias, + class ElementSource, + class ElementScalar, + int AlignmentAux, + int AlignmentBias, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile, + class SmemLayoutAtom, + class CopyOpS2R +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombDeEltActDePerRowBias< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >, + CtaTileShapeMNK, + EpilogueTile, + SmemLayoutAtom, + CopyOpS2R +> : Sm90LinCombDeEltActDePerRowBias< + CtaTileShapeMNK, EpilogueTile, 
StagesC, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpS2R, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + > { + + using Impl = + Sm90LinCombDeEltActDePerRowBias< + CtaTileShapeMNK, EpilogueTile, StagesC, cutlass::gemm::TagToStrideC_t, SmemLayoutAtom, CopyOpS2R, ActivationFn, + ElementOutput, ElementCompute, ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + using Operation = + fusion::LinCombDeEltActDePerRowBias< + GmemLayoutTagAux, ActivationFn, ElementOutput, ElementCompute, + ElementAux, ElementBias, ElementSource, ElementScalar, AlignmentAux, AlignmentBias, RoundStyle + >; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + using StrideAlpha = Stride<_0,_0,int64_t>; + using StrideBeta = Stride<_0,_0,int64_t>; + StrideAlpha dAlpha = {_0{}, _0{}, 0}; + StrideBeta dBeta = {_0{}, _0{}, 0}; + + using ActivationArguments = typename Sm90Compute::Arguments; + ActivationArguments activation = ActivationArguments(); + + using StrideAux = cutlass::gemm::TagToStrideC_t; + ElementAux const* aux_ptr = nullptr; + StrideAux dAux = {}; + + using StrideBias = Stride<_1,_0,int64_t>; + ElementBias* dbias_ptr = nullptr; + StrideBias dDbias = {}; + + operator typename Impl::Arguments() const { + return + { // unary op : identity/convert + { // unary op : reduce(activation(beta * C + (alpha * acc), aux)) + { // binary op : activation(beta * C + (alpha * acc), aux) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}, {dBeta}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}, {dAlpha}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + {aux_ptr, ElementAux(0), dAux}, // leaf args : aux + activation // binary args : activation + }, // end binary op + {dbias_ptr, ElementCompute(0), dDbias} // unary args : reduce + }, // end unary op + {} // unary args : identity/convert + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = per-column alpha * per-row alpha * acc + beta * c +template< + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + int AlignmentScalar = 128 / sizeof_bits_v, // Alignment of per-column and per-row scaling vectors + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90OuterProdLinComb = + Sm90EVT, // c(beta) * c(C) + c(alpha * acc) + Sm90ScalarBroadcast>, // beta + Sm90SrcFetch, // C + Sm90EVT, // c(alpha) * c(acc) + Sm90OuterProduct<0, CtaTileShapeMNK, ElementScalar, Stride<_1,_0,int>, Stride<_0,_1,int>, AlignmentScalar>, // alpha_col * alpha_row + Sm90AccFetch // acc + > + >; + +template < + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementSource, + class ElementScalar, + int AlignmentScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + OuterProdLinComb, + CtaTileShapeMNK, + 
EpilogueTile +> : Sm90OuterProdLinComb { + using Impl = Sm90OuterProdLinComb; + using Operation = OuterProdLinComb; + + struct Arguments { + + // Give a name and flat ordering to the fusion callback args + using StrideCol = Stride<_1,_0,int>; + using StrideRow = Stride<_0,_1,int>; + using StrideBeta = Stride<_0,_0,int>; + ElementScalar const* alpha_ptr_col = nullptr; + ElementScalar const* alpha_ptr_row = nullptr; + ElementScalar beta = static_cast(0); + ElementScalar const* beta_ptr = nullptr; + StrideCol dAlphaCol = {}; + StrideRow dAlphaRow = {}; + StrideBeta dBeta = {}; + + // Conversion to the args expected by the visitor implementation + // to_underlying_arguments will implicitly call this + operator typename Impl::Arguments() const { + return + { + {beta, beta_ptr, dBeta}, // leaf args : beta + {}, // leaf args : C + { + { alpha_ptr_col, alpha_ptr_row, dAlphaCol, dAlphaRow }, // leaf args : alpha cols / rows + {}, // leaf args : acc + {} + }, + {} + }; + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// D = softmax(top_k(alpha * acc + beta * C)) +template< + int TopK, + int FragmentSize, + class CtaTileShapeMNK, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + class ElementSource = ElementOutput, + class ElementScalar = ElementCompute, + FloatRoundStyle RoundStyle = FloatRoundStyle::round_to_nearest +> +using Sm90LinCombTopKSoftmaxCol = + Sm90EVT, // softmax(top_k(beta * C + (alpha * acc))) + Sm90LinearCombination // beta * C + (alpha * acc) + >; + +template < + int TopK, + int StagesC, + int StagesD, + int FragmentSize, + bool ReuseSmemC, + bool DelayTmaStore, + class ElementOutput, + class ElementCompute, + class ElementSource, + class ElementScalar, + FloatRoundStyle RoundStyle, + class CtaTileShapeMNK, + class EpilogueTile +> +struct FusionCallbacks< + epilogue::Sm90TmaWarpSpecialized, + fusion::LinCombTopKSoftmaxCol, + CtaTileShapeMNK, + EpilogueTile +> : Sm90LinCombTopKSoftmaxCol { + + using Impl = Sm90LinCombTopKSoftmaxCol::type, ElementCompute, ElementSource, ElementScalar, RoundStyle>; + using Operation = fusion::LinCombTopKSoftmaxCol; + + struct Arguments { + ElementScalar alpha = ElementScalar(1); + ElementScalar beta = ElementScalar(0); + ElementScalar const* alpha_ptr = nullptr; + ElementScalar const* beta_ptr = nullptr; + + operator typename Impl::Arguments() const { + return + { // unary op: activation(beta * C + (alpha * acc)) + { // ternary op : beta * C + (alpha * acc) + {{beta}, {beta_ptr}}, // leaf args : beta + {}, // leaf args : C + { // binary op : alpha * acc + {{alpha}, {alpha_ptr}}, // leaf args : alpha + {}, // leaf args : acc + {} // binary args : multiplies + }, // end binary op + {} // ternary args : multiply_add + }, // end ternary op + {} // unary args: activation + }; // end unary op + } + }; + + // Ctor inheritance + using Impl::Impl; +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { +template > +struct get_element_aux { + using type = void; +}; + +template +struct get_element_aux> { + using type = typename FusionOpOrCallbacks::ElementAux; +}; + +template +struct get_element_aux, cute::void_t<>> { + using type = typename get_element_aux::type; +}; + +template +struct get_element_aux, cute::void_t::Operation>> { + private: + using Operation = typename FusionCallbacks::Operation; + public: + using type = typename get_element_aux::type; +}; +} 
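// The get_element_aux_t alias defined just below is a convenience for callers (for
// example, epilogue collective builders) that need the auxiliary element type of a fusion
// operation, given either the operation itself or its FusionCallbacks specialization.
// Illustrative use only; MyCallbacks is a placeholder name:
//
//   using ElementAux = get_element_aux_t<MyCallbacks>;  // the op's ElementAux, or void
//                                                       // if the op does not define one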
// namespace cutlass::epilogue::fusion::detail + +template <class FusionOpOrCallbacks> +using get_element_aux_t = typename detail::get_element_aux<FusionOpOrCallbacks>::type; + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// + + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp new file mode 100644 index 0000000000..321daa6bcc --- /dev/null +++ b/include/cutlass/epilogue/fusion/sm90_visitor_compute_tma_warpspecialized.hpp @@ -0,0 +1,841 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! 
\file + \brief Visitor tree compute operations for the sm90 TMA warp-specialized (ws) epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/array.h" +#include "cutlass/numeric_conversion.h" +#include "cutlass/epilogue/thread/activation.h" + +#include "cute/tensor.hpp" + +#include "cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp" +#include "cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// N-nary Elementwise Compute Operation +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +// The template argument provided for ComputeFn must be able to accept +// exactly one template parameter. In Standard C++, it's OK for +// ComputeFn to have other template parameters, as long as those have +// defaults. For example, the following struct Foo would work. +// +// template +// struct Foo { +// CUTLASS_HOST_DEVICE auto operator() (A a, B b); +// }; +// +// However, some compilers, such as Clang, require that the argument +// take _exactly_ one template parameter. This is nonstandard C++ +// behavior. One work-around for this case is to create a subclass +// with exactly one template parameter, and then use that subclass as +// the template argument. +// +// template +// struct FooHomogeneous : public Foo {}; +// +template< + template class ComputeFn, + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle, + class = void +> +struct Sm90Compute { +private: + using EmptyArguments = typename Sm90VisitorImpl<>::Arguments; + + template + struct ComputeArguments { + using type = EmptyArguments; + }; + + // partial specialization for compute fns that define an Arguments member, e.g. 
activation hyperparameters + template + struct ComputeArguments> { + using type = typename Fn::Arguments; + }; + +public: + struct SharedStorage { }; + + using Arguments = typename ComputeArguments>::type; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const&, Arguments const& args, void*) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const&, Arguments const&) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_HOST_DEVICE + Sm90Compute() + : params() {} + + CUTLASS_HOST_DEVICE + Sm90Compute(Params const& params, SharedStorage const& shared_storage) + : params(params) {} + + Params const params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(Params const& params) + : params(params) {} + + Params const& params; + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const&... frg_inputs) { + return transform_apply(cute::make_tuple(frg_inputs...), + [&] (auto&& frg_input) { + using ElementInput = typename cute::remove_cvref_t::Element; + using ConvertInput = NumericArrayConverter; + ConvertInput convert_input{}; + + return convert_input(frg_input); + }, + [&] (auto&&... cvt_frg_inputs) { + using ComputeOutput = ComputeFn>; + ComputeOutput compute_output{}; + + if constexpr (cute::is_same_v) { + using ElementComputeOutput = + typename cute::remove_cvref_t::Element; + using ConvertOutput = NumericArrayConverter; + ConvertOutput convert_output{}; + return convert_output(compute_output(cvt_frg_inputs...)); + } + else { + using ElementComputeOutput = + typename cute::remove_cvref_t::Element; + using ConvertOutput = NumericArrayConverter; + ConvertOutput convert_output{}; + return convert_output(compute_output(cvt_frg_inputs..., params)); + } + } + ); + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
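// Restating the ComputeFn requirement from the comment at the top of this file with the
// template brackets spelled out (Foo and FooHomogeneous are the placeholder names used in
// that comment, not real CUTLASS types):
//
//   template <class A, class B = A>
//   struct Foo {
//     CUTLASS_HOST_DEVICE auto operator()(A a, B b);
//   };
//
//   // Work-around for compilers that require the template template argument to accept
//   // exactly one template parameter:
//   template <class T>
//   struct FooHomogeneous : public Foo<T, T> {};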
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + return ConsumerStoreCallbacks(params); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Performance Optimized Specializations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +// beta * C + Z +template < + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle, + class InputScaleOp, // beta + class ElementSource, // C + class InputAddOp // Z +> +struct Sm90TreeVisitor< + Sm90Compute().is_zero())>>, + InputScaleOp, + Sm90SrcFetch, + InputAddOp +> : Sm90VisitorImpl< + InputScaleOp, + Sm90SrcFetch, + InputAddOp, + Sm90Compute + > +{ + using Impl = + Sm90VisitorImpl< + InputScaleOp, + Sm90SrcFetch, + InputAddOp, + Sm90Compute + >; + using Params = typename Impl::Params; + using SharedStorage = typename Impl::SharedStorage; + + CUTLASS_HOST_DEVICE + Sm90TreeVisitor() {} + + CUTLASS_HOST_DEVICE + Sm90TreeVisitor( + Params const& params, + SharedStorage const& shared_storage) + : Impl(params, shared_storage) {} + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + auto const& scale_op = get<0>(Impl::ops); + auto const& added_op = get<2>(Impl::ops); + if constexpr (detail::IsScalarBroadcast::value && not is_void_v) { + return (get<2>(scale_op.params_ptr->dScalar[0]) != 0 && scale_op.params_ptr->scalar_ptrs[0] != nullptr) || + is_C_load_needed() || + added_op.is_producer_load_needed(); + } + else { + return is_C_load_needed() || added_op.is_producer_load_needed(); + } + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + auto const& scale_op = get<0>(Impl::ops); + auto const& src_op = get<1>(Impl::ops); + auto const& added_op = get<2>(Impl::ops); + return (not scale_op.is_zero() && src_op.is_C_load_needed()) || added_op.is_C_load_needed(); + } + + template + struct ConsumerStoreCallbacks : CallbacksImpl { + CUTLASS_DEVICE + ConsumerStoreCallbacks(bool is_C_load_needed, CallbacksImpl&& impl) + : is_C_load_needed(is_C_load_needed), CallbacksImpl(cute::forward(impl)) { } + + bool is_C_load_needed; + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_added = get<2>(CallbacksImpl::callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n); + + using ElementZ = typename decltype(frg_added)::Element; + using ConvertZ = NumericArrayConverter; + using ConvertI = NumericArrayConverter; + ConvertZ convert_Z{}; + ConvertI convert_I{}; + + Array frg_I = convert_Z(frg_added); + + if constexpr (!is_void_v) { + Array frg_scalar = get<0>(CallbacksImpl::callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n); + Array frg_source = get<1>(CallbacksImpl::callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n); + + using ElementX = typename decltype(frg_scalar)::Element; + using ElementY = typename decltype(frg_source)::Element; + using ConvertX = NumericArrayConverter; + using ConvertY = NumericArrayConverter; + using ComputeI = multiply_add>; + ConvertX convert_X{}; + ConvertY convert_Y{}; + ComputeI compute_I{}; + + frg_I = compute_I(convert_X(frg_scalar), convert_Y(frg_source), frg_I); + } + + return convert_I(frg_I); + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
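// Note on the specialization above: when the added term does not itself consume C and the
// beta scaling is known to be zero (or ElementSource is void), is_C_load_needed() returns
// false, so the consumer-store path clears the C register fragment instead of loading it.
// Skipping that load is the point of this performance-optimized tree.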
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto callbacks_tuple = Impl::template get_consumer_store_callbacks(args); + bool is_C_load_needed = this->is_C_load_needed(); + if (not is_C_load_needed) { + cute::clear(args.tCrC); + } + return ConsumerStoreCallbacks( + is_C_load_needed, std::move(callbacks_tuple)); + } +}; + +// ReLU with aux bit tensor dReLU/dZ +// Aux(i) = Z(i) >= 0 ? 1 : 0 +namespace detail { +// Placeholder node so we can retain standard EVT structure +template +struct Sm90ReLUAuxStore : Sm90VisitorImpl<> { + struct SharedStorage {}; + + struct Arguments { + cutlass::uint1b_t* ptr_aux = nullptr; + StrideMNL dAux = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90ReLUAuxStore() { } + + CUTLASS_HOST_DEVICE + Sm90ReLUAuxStore(Params const& params, SharedStorage const& shared_storage) { } +}; +} // namespace detail + +// Specialization on the generic compute+aux EVT +template < + // Compute node + template class Activation, + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle, + // Aux node + int Stages, + class EpilogueTile, + class StrideMNL, + class SmemLayoutAtom, + class CopyOpR2S, + int Alignment, + bool EnableNullptr, + // Input node + class InputOp +> +struct Sm90TreeVisitor< + Sm90Compute, cutlass::epilogue::thread::ReLu> || + cute::is_same_v, cutlass::epilogue::thread::Clamp> || + cute::is_same_v, cutlass::epilogue::thread::ThresholdReLU> >>, + Sm90TreeVisitor< + Sm90AuxStore< + Stages, + EpilogueTile, + cutlass::uint1b_t, + RoundStyle, + StrideMNL, + SmemLayoutAtom, + CopyOpR2S, + Alignment, + EnableNullptr + >, + InputOp + > +> : Sm90VisitorImpl< + Sm90VisitorImpl< + InputOp, + detail::Sm90ReLUAuxStore + >, + Sm90Compute + > +{ + using Impl = + Sm90VisitorImpl< + Sm90VisitorImpl< + InputOp, + detail::Sm90ReLUAuxStore + >, + Sm90Compute + >; + using Params = typename Impl::Params; + using SharedStorage = typename Impl::SharedStorage; + + CUTLASS_HOST_DEVICE + Sm90TreeVisitor() {} + + CUTLASS_HOST_DEVICE + Sm90TreeVisitor(Params const& params_, SharedStorage const& shared_storage) + : params(params_), Impl(params_, shared_storage) {} + + Params const& params; + + template + struct ConsumerStoreCallbacks : CallbacksImpl { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + RTensor&& tC_rAux, + GTensor&& tC_gAux, + CTensor tC_cAux, + ThrResidue residue_tC_cAux, + Params const& params, + CallbacksImpl&& impl) + : tC_rAux(cute::forward(tC_rAux)), + tC_gAux(cute::forward(tC_gAux)), + tC_cAux(tC_cAux), + residue_tC_cAux(residue_tC_cAux), + params(params), + CallbacksImpl(cute::forward(impl)) {} + + RTensor tC_rAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + GTensor tC_gAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + CTensor tC_cAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ThrResidue residue_tC_cAux; + Params const& params; + + template + CUTLASS_DEVICE Array + 
visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + // Unpack callbacks + params + auto& [callbacks_input_aux, callbacks_compute] = CallbacksImpl::callbacks_tuple; + auto& [callbacks_input, callbacks_aux] = callbacks_input_aux.callbacks_tuple; + auto const& [params_input_aux, params_compute] = params; + auto const& [params_input, params_aux] = params_input_aux; + + // Visit the input node + Array frg_input = callbacks_input.visit(frg_acc, epi_v, epi_m, epi_n); + + // Compute activation + aux + using ElementInput = typename decltype(frg_input)::Element; + using ConvertInput = NumericArrayConverter; + using ConvertAux = PackPredicates; + using ComputeOutput = Activation; + using ConvertOutput = NumericArrayConverter; + ConvertInput convert_input{}; + ComputeOutput relu{}; + ConvertAux convert_aux{}; + ConvertOutput convert_output{}; + + Array frg_compute = convert_input(frg_input); + bool frg_aux[FragmentSize]; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + ElementCompute pre_relu = frg_compute[i]; + if constexpr (cute::is_same_v, cutlass::epilogue::thread::Clamp> || + cute::is_same_v, cutlass::epilogue::thread::ThresholdReLU>) { + frg_compute[i] = relu(frg_compute[i], params_compute); + } + else { + frg_compute[i] = relu(frg_compute[i]); + } + if constexpr (cute::is_same_v) { + uint32_t aux; + asm volatile("set.equ.u32.f32 %0, %1, %2;\n" : "=r"(aux) : "f"(frg_compute[i]), "f"(pre_relu)); // NaN outputs 1 in Aux + frg_aux[i] = static_cast(aux); + } else if constexpr (cute::is_same_v) { + uint32_t aux; + cutlass::half_t compute = frg_compute[i]; + asm volatile("set.equ.u32.f16 %0, %1, %2;\n" : "=r"(aux) : "h"(compute.raw()), "h"(pre_relu.raw())); // NaN outputs 1 in Aux + frg_aux[i] = static_cast(aux); + } else { + frg_aux[i] = frg_compute[i] == pre_relu; + } + } + + static_assert(FragmentSize % 8 == 0, "Predicate vector must be byte-aligned"); + Tensor tC_rAux_frg = recast(coalesce(tC_rAux(_,_,_,epi_m,epi_n))); // (EPI_V) + tC_rAux_frg(epi_v) = convert_aux(frg_aux); + + return convert_output(frg_compute); + } + + CUTLASS_DEVICE void + end() { + // Unpack callbacks + params + auto& [callbacks_input_aux, callbacks_compute] = CallbacksImpl::callbacks_tuple; + auto& [callbacks_input, callbacks_aux] = callbacks_input_aux.callbacks_tuple; + auto const& [params_input_aux, params_compute] = params; + auto const& [params_input, params_aux] = params_input_aux; + + // Visit the input node + callbacks_input.end(); + + // Nullptr is no-op + if constexpr (EnableNullptr) { + if (params_aux.ptr_aux == nullptr) { + return; + } + } + + // Compute vectorization + constexpr auto MCL = decltype(max_common_layout(tC_rAux, tC_gAux)){}; + constexpr int V = cute::min(Alignment, size(MCL)); + // Copy vectorizes into byte-aligned stores + if constexpr (V > 1 && V % 8 == 0) { + using VecType = uint_bit_t; + Tensor tC_rAux_vec = recast(tC_rAux); + Tensor tC_gAux_vec = recast(tC_gAux); + Tensor tC_cAux_vec = tensor<1>(zipped_divide(tC_cAux, MCL.compose(Int{}))); + auto predicate_fn = [&] (auto&&... coords) { return elem_less(tC_cAux_vec(coords...), residue_tC_cAux); }; + copy_if(predicate_fn, tC_rAux_vec, tC_gAux_vec); + } + // sub-byte vectorization, must serialize threads + else { + // Assumes no inter-warp sharing of bytes (most copy layouts should satisfy this) + int lane_idx = canonical_lane_idx(); + auto predicate_fn = [&] (auto&&... 
coords) { return elem_less(tC_cAux(coords...), residue_tC_cAux); }; + CUTLASS_PRAGMA_NO_UNROLL + for (int i = 0; i < NumThreadsPerWarp; ++i) { + if (lane_idx == i) { + copy_if(predicate_fn, tC_rAux, tC_gAux); + } + __syncwarp(); + } + } + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + // Unpack params + auto const& [params_input_aux, params_compute] = params; + auto const& [params_input, params_aux] = params_input_aux; + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + gmem_ptr ptr_aux = make_gmem_ptr(subbyte_iterator(params_aux.ptr_aux)); + Tensor mAux = make_tensor(ptr_aux, make_layout(make_shape(M,N,L), params_aux.dAux)); // (M,N,L) + Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) + + Tensor tC_gAux = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + gAux, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tC_rAux = make_tensor(shape(tC_gAux)); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + auto callbacks_impl = Impl::template get_consumer_store_callbacks(args); + return ConsumerStoreCallbacks( + cute::move(tC_rAux), cute::move(tC_gAux), args.tCcD, args.residue_tCcD, params, cute::move(callbacks_impl)); + } +}; + +// Aux load for uint1b_t +template < + int Stages, + class EpilogueTile, + class StrideMNL, + class SmemLayoutAtom, + class CopyOpS2R, + int Alignment, + bool EnableNullptr +> +struct Sm90AuxLoad< + Stages, + EpilogueTile, + cutlass::uint1b_t, + StrideMNL, + SmemLayoutAtom, + CopyOpS2R, + Alignment, + EnableNullptr +> { + static_assert(Alignment % 128 == 0, "sub-16B alignment not supported yet"); + + struct SharedStorage {}; + + struct Arguments { + cutlass::uint1b_t const* ptr_aux = nullptr; + cutlass::uint1b_t null_default = cutlass::uint1b_t(0); + StrideMNL dAux = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90AuxLoad() { } + + CUTLASS_HOST_DEVICE + Sm90AuxLoad(Params const& params, SharedStorage const&) + : params(params) { } + + Params const params; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(RTensor&& tC_rAux_, GTensor&& tC_gAux_, CTensor tC_cAux_, ThrResidue residue_tC_cAux_, Params const& params_) + : tC_rAux(cute::forward(tC_rAux_)), + tC_gAux(cute::forward(tC_gAux_)), + tC_cAux(tC_cAux_), + residue_tC_cAux(residue_tC_cAux_), + params(params_) {} + + RTensor tC_rAux; 
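// Sizing reminder for the bit-packed ReLU mask handled here (numbers are illustrative):
// each Aux element is a single uint1b_t, so a fragment of FragmentSize = 8 predicates
// packs into one byte (hence the FragmentSize % 8 == 0 assertion in the store path above),
// and a 128-element aligned access moves 128 bits = 16 B at a time, which is why this
// load specialization asserts Alignment % 128 == 0 ("sub-16B alignment not supported").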
// (CPY,CPY_M,CPY_N,{EPI_M,EPI_N}) + GTensor tC_gAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + CTensor tC_cAux; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ThrResidue residue_tC_cAux; + Params const& params; + + CUTLASS_DEVICE void + begin() { + if constexpr (decltype(cute::rank(tC_rAux))::value == 5) { + if constexpr (EnableNullptr) { + if (params.ptr_aux == nullptr) { + return; + } + } + + constexpr auto MCL = decltype(max_common_layout(tC_rAux, tC_gAux)){}; + constexpr int V = cute::min(Alignment, size(MCL)); + if constexpr (V > 1) { + using VecType = uint_bit_t; + Tensor tC_gAux_vec = recast(tC_gAux); + Tensor tC_rAux_vec = recast(tC_rAux); + Tensor tC_cAux_vec = tensor<1>(zipped_divide(tC_cAux, MCL.compose(Int{}))); + auto predicate_fn = [&] (auto&&... coords) { return elem_less(tC_cAux_vec(coords...), residue_tC_cAux); }; + copy_if(predicate_fn, tC_gAux_vec, tC_rAux_vec); + } + else { + auto predicate_fn = [&] (auto&&... coords) { return elem_less(tC_cAux(coords...), residue_tC_cAux); }; + copy_if(predicate_fn, tC_gAux, tC_rAux); + } + } + } + + CUTLASS_DEVICE void + begin_loop(int epi_m, int epi_n) { + if constexpr (decltype(cute::rank(tC_rAux))::value == 3) { + if constexpr (EnableNullptr) { + if (params.ptr_aux == nullptr) { + return; + } + } + + auto predicate_fn = [&] (auto&&... coords) { return elem_less(tC_cAux(_,_,_,epi_m,epi_n)(coords...), residue_tC_cAux); }; + copy_if(predicate_fn, tC_gAux(_,_,_,epi_m,epi_n), tC_rAux); + } + } + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + using ElementRegister = typename remove_cvref_t::value_type; + if constexpr (decltype(cute::rank(tC_rAux))::value == 3) { + return recast>(coalesce(tC_rAux))(epi_v); + } + else { + return recast>(coalesce(tC_rAux(_,_,_,epi_m,epi_n)))(epi_v); + } + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + gmem_ptr ptr_aux = make_gmem_ptr(subbyte_iterator(params.ptr_aux)); + Tensor mAux = make_tensor(ptr_aux, make_layout(make_shape(M,N,L), params.dAux)); // (M,N,L) + Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) + + Tensor tC_gAux = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + gAux, args.epi_tile, args.tiled_copy, args.thread_idx); + + // If byte-unaligned vectorization, store in registers as uint32_t to reduce redundant pack+unpack instruction sequences + constexpr int V = decltype(max_common_vector(tC_gAux.layout(), make_layout(tC_gAux.shape())))::value; + Tensor tC_rAux = [&] () { + if constexpr (V % 8 != 0) { + return make_tensor(take<0,3>(shape(tC_gAux))); // (CPY,CPY_M,CPY_N) + } else { + return make_tensor(shape(tC_gAux)); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + } + }(); + + if constexpr (EnableNullptr) { + if (params.ptr_aux == nullptr) { + fill(tC_rAux, params.null_default); + } + } + + return ConsumerStoreCallbacks( + cute::move(tC_rAux), cute::move(tC_gAux), args.tCcD, args.residue_tCcD, params); + } +}; + +// dReLU specialization +template< + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle +> +struct Sm90Compute< + cutlass::epilogue::thread::dReLU, + ElementOutput, + ElementCompute, + RoundStyle +> : Sm90VisitorImpl<> { + + using Sm90VisitorImpl<>::Sm90VisitorImpl; + + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input, + Array const& frg_aux) { + using ConvertInput = NumericArrayConverter; + using ComputeOutput = cutlass::epilogue::thread::dReLU>; + using ConvertOutput = NumericArrayConverter; + ConvertInput convert_input{}; + ComputeOutput compute_output{}; + ConvertOutput convert_output{}; + + return convert_output(compute_output(convert_input(frg_input), frg_aux)); // don't convert frg_aux for dReLU + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + return ConsumerStoreCallbacks(); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp new file mode 100644 index 0000000000..66b1086efc --- /dev/null +++ b/include/cutlass/epilogue/fusion/sm90_visitor_load_tma_warpspecialized.hpp @@ -0,0 +1,1581 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. 
Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Visitor tree load operations for the sm90 TMA warp-specialized (ws) epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/arch/barrier.h" +#include "cutlass/epilogue/collective/detail.hpp" + +#include "cute/tensor.hpp" +#include "sm90_visitor_tma_warpspecialized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Elementwise Fetch Operations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +// returns accumulator +struct Sm90AccFetch : Sm90VisitorImpl<> { + + using Sm90VisitorImpl<>::Sm90VisitorImpl; + + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + return frg_acc; + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + return ConsumerStoreCallbacks{}; + } +}; + +// Split tree visitor fetches intermediate results from temporary accumulators +using Sm90SplitTreeFetch = Sm90AccFetch; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// returns C +template +struct Sm90SrcFetch : Sm90VisitorImpl<> { + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return is_C_load_needed(); + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return not is_void_v; + } + + CUTLASS_DEVICE bool + is_zero() const { + return is_void_v; + } + + using Sm90VisitorImpl<>::Sm90VisitorImpl; + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(SrcTensor const& tCrC) + : tCrC(tCrC) {} + + SrcTensor const& tCrC; // (CPY,CPY_M,CPY_N) + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + return recast>(tCrC)(epi_v); + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + // register type may differ from logical type so we can't assert matching types here + return ConsumerStoreCallbacks(args.tCrC); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Elementwise Load Operations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + class EpilogueTile, + class Element, + class StrideMNL, + class SmemLayoutAtom, + class CopyOpS2R, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true // Fallback scalar broadcast for nullptr params +> +struct Sm90AuxLoad { + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + + constexpr static bool is_m_major = epilogue::collective::detail::is_m_major(); + // Find the max contiguous layout usable by TMA (if EpilogueTile is a non-compact tiler) + using SmemShapeTma = decltype(make_shape( + max_common_vector(make_layout(get<0>(EpilogueTile{})),make_layout(get<0>(EpilogueTile{}))), + max_common_vector(make_layout(get<1>(EpilogueTile{})),make_layout(get<1>(EpilogueTile{}))))); + using SmemLayoutTma = decltype(tile_to_shape( + SmemLayoutAtom{}, SmemShapeTma{}, + cute::conditional_t, Step<_1,_2>>{} )); + using SmemLayout = decltype(tile_to_shape( + SmemLayoutTma{}, + make_shape(size<0>(shape(EpilogueTile{})), size<1>(shape(EpilogueTile{})), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{} )); + using CopyOpG2S = + SM90_TMA_LOAD + ; + + struct SharedStorage { + alignas(cutlass::detail::alignment_for_swizzle(SmemLayout{})) + array_aligned smem_aux; + }; + + struct Arguments { + Element const* ptr_aux = nullptr; + Element null_default = Element(0); + StrideMNL dAux = {}; + }; + + struct Params { + using TMA_Aux = decltype(make_tma_copy( + CopyOpG2S{}, + make_tensor(make_gmem_ptr(static_cast(nullptr)), repeat_like(StrideMNL{}, int32_t(0)), append<3>(StrideMNL{}, _0{})), + take<0,2>(SmemLayoutTma{}))); + TMA_Aux tma_load_aux; + Element null_default = Element(0); + bool use_default = false; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + // Optionally append 1s until problem shape 
is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + auto M_AUX = + size(M) + ; + Tensor tensor_aux = make_tensor(make_gmem_ptr(args.ptr_aux), make_layout(make_shape(M_AUX,N,L), append<3>(args.dAux, _0{}))); + typename Params::TMA_Aux tma_load_aux = make_tma_copy(CopyOpG2S{}, tensor_aux, take<0,2>(SmemLayoutTma{})); + + bool use_default = false; + if constexpr (EnableNullptr) { + use_default = args.ptr_aux == nullptr; + } + + return Params{tma_load_aux, args.null_default, use_default}; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90AuxLoad() { } + + CUTLASS_HOST_DEVICE + Sm90AuxLoad(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms), + smem_aux(const_cast(shared_storage.smem_aux.data())) { } + + Params const* params_ptr; + Element* smem_aux; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return true; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return (params_ptr->use_default && params_ptr->null_default == Element(0)); + } + + template + struct ProducerLoadCallbacks : EmptyProducerLoadCallbacks { + CUTLASS_DEVICE + ProducerLoadCallbacks(GTensor&& bGS_gAux, STensor&& bGS_sAux, Params const* params_ptr) + : bGS_gAux(cute::forward(bGS_gAux)), + bGS_sAux(cute::forward(bGS_sAux)), + params_ptr(params_ptr) {} + + GTensor bGS_gAux; // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) + STensor bGS_sAux; // (TMA,TMA_M,TMA_N,PIPE) + Params const* params_ptr; + + CUTLASS_DEVICE void + step(uint64_t* full_mbarrier_ptr, int epi_m, int epi_n, int load_iteration, bool issue_tma_load) { + if constexpr (EnableNullptr) { + if (params_ptr->use_default) { + return; + } + } + + if (issue_tma_load) { + // Increment the expected transaction bytes of the current stage's mbarrier by the subtile's byte-size + constexpr uint32_t copy_bytes = size(take<0,2>(SmemLayout{})) * sizeof_bits_v / 8; + cutlass::arch::ClusterTransactionBarrier::expect_transaction(full_mbarrier_ptr, copy_bytes); + // Issue the TMA load + constexpr uint16_t mcast_mask = 0; + int load_pipe_index = load_iteration % Stages; + copy(params_ptr->tma_load_aux.with(*full_mbarrier_ptr, mcast_mask), + bGS_gAux(_,_,_,epi_m,epi_n), bGS_sAux(_,_,_,load_pipe_index)); + } + } + }; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + auto coord_shape = + make_coord(m, n, l) + ; + Tensor mAux_mn = params_ptr->tma_load_aux.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor mAux = coalesce(mAux_mn, take<0,2>(args.tile_shape_mnk)); + Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), coord_shape); // (CTA_M,CTA_N) + + Tensor gAux_epi = flat_divide(gAux, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + Tensor sAux_epi = make_tensor(make_smem_ptr(smem_aux), SmemLayout{}); // (EPI_TILE_M,EPI_TILE_N,PIPE) + + ThrCopy thrblk_g2s = 
params_ptr->tma_load_aux.get_slice(_0{}); + Tensor bGS_gAux = thrblk_g2s.partition_S(gAux_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) + Tensor bGS_sAux = thrblk_g2s.partition_D(sAux_epi); // (TMA,TMA_M,TMA_N,PIPE) + + return ProducerLoadCallbacks( + cute::move(bGS_gAux), cute::move(bGS_sAux), params_ptr); + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(RTensor&& tC_rAux, TiledS2R tiled_s2r, STensorS2R&& tSR_sAux, Params const* params_ptr) + : tC_rAux(cute::forward(tC_rAux)), + tiled_s2r(tiled_s2r), + tSR_sAux(cute::forward(tSR_sAux)), + params_ptr(params_ptr) { } + + TiledS2R tiled_s2r; + RTensor tC_rAux; // (CPY,CPY_M,CPY_N) + STensorS2R tSR_sAux; // (S2R,S2R_M,S2R_N,PIPE) + Params const* params_ptr; + + CUTLASS_DEVICE void + previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { + if constexpr (EnableNullptr) { + if (params_ptr->use_default) { + fill(tC_rAux, params_ptr->null_default); + return; + } + } + + using RLayoutS2R = decltype(cute::layout(TiledS2R{}.get_slice(0).retile_S(RTensor{}))); + Tensor tSR_rAux = make_tensor(tC_rAux.data(), RLayoutS2R{}); // (S2R,S2R_M,S2R_N) + + int load_pipe_index = load_iteration % Stages; + copy(tiled_s2r, tSR_sAux(_,_,_,load_pipe_index), tSR_rAux); + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Tensor tC_rAux_frg = recast>(coalesce(tC_rAux)); // (EPI_V) + + return tC_rAux_frg(epi_v); + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + + Tensor mAux_mn = params_ptr->tma_load_aux.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor mAux = coalesce(mAux_mn, take<0,2>(args.tile_shape_mnk)); + Tensor tC_gAux = sm90_partition_for_epilogue(mAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tC_rAux = make_tensor(take<0,3>(shape(tC_gAux))); // (CPY,CPY_M,CPY_N) + + auto tiled_s2r = conditional_return( + make_tiled_copy_S(Copy_Atom{}, args.tiled_copy), + make_tiled_copy_D(Copy_Atom{}, args.tiled_copy) + ); + Tensor sAux_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(smem_aux), SmemLayout{})); // (EPI_TILE_M,EPI_TILE_N,PIPE) + auto tSR_sAux = tiled_s2r.get_slice(args.thread_idx).partition_S(sAux_epi); // (S2R,S2R_M,S2R_N,PIPE) + + return ConsumerStoreCallbacks( + cute::move(tC_rAux), tiled_s2r, cute::move(tSR_sAux), params_ptr); + } +}; + +template < + class Element, + class EpilogueTile, // Unused + class LayoutOrStrideMNL, + class SmemLayoutAtom, // Unused + class CopyOpS2R, // Unused + int Alignment, + bool EnableNullptr +> +struct Sm90AuxLoad< + 0, EpilogueTile, Element, LayoutOrStrideMNL, + SmemLayoutAtom, CopyOpS2R, Alignment, EnableNullptr +> { + using ElementAux = Element; + using StrideMNL = cutlass::gemm::TagToStrideC_t; + + struct SharedStorage { }; + + struct Arguments { + Element const* ptr_aux = nullptr; + Element null_default = Element(0); + StrideMNL dAux = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + 
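// Illustrative sketch (editor's addition, not part of this patch): a minimal
// host-side analogue of the EnableNullptr fallback used by Sm90AuxLoad above.
// When ptr_aux is null, every visited fragment is filled with null_default
// instead of being staged through TMA and shared memory. The function name and
// signature below are hypothetical.
#include <array>
#include <cstddef>

template <class Element, int FragmentSize>
std::array<Element, FragmentSize>
load_aux_fragment(Element const* ptr_aux, Element null_default, std::size_t offset) {
  std::array<Element, FragmentSize> frg{};
  if (ptr_aux == nullptr) {              // EnableNullptr fallback path
    frg.fill(null_default);
    return frg;
  }
  for (int i = 0; i < FragmentSize; ++i) {
    frg[i] = ptr_aux[offset + i];        // the device path stages via TMA + S2R copies instead
  }
  return frg;
}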
template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90AuxLoad() { } + + CUTLASS_HOST_DEVICE + Sm90AuxLoad(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template< + class GTensorG2R, + class RTensor, + class CTensorG2R, + class ProblemShapeMNL + > + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(GTensorG2R&& tC_gAux, + RTensor&& tC_rAux, + CTensorG2R&& tC_cAux, + ProblemShapeMNL problem_shape_mnl, + Params const* params_ptr) + : tC_gAux(cute::forward(tC_gAux)), + tC_rAux(cute::forward(tC_rAux)), + tC_cAux(cute::forward(tC_cAux)), + problem_shape_mnl(problem_shape_mnl), + params_ptr(params_ptr) {} + + GTensorG2R tC_gAux; + RTensor tC_rAux; + CTensorG2R tC_cAux; + ProblemShapeMNL problem_shape_mnl; + Params const* params_ptr; + + CUTLASS_DEVICE void + begin_loop(int epi_m, int epi_n) { + if constexpr (EnableNullptr) { + if (params_ptr->ptr_aux == nullptr) { + fill(tC_rAux, params_ptr->null_default); + return; + } + } + constexpr auto MCL = decltype(max_common_layout(tC_gAux(_,_,_,_0{},_0{}), tC_rAux)){}; + constexpr int V = cute::min(Alignment, size(MCL)); + + Tensor tC_cAux_mn = tC_cAux(_,_,_,epi_m,epi_n); + Tensor tC_cAux_vec = tensor<1>(zipped_divide(coalesce(tC_cAux_mn), MCL.compose(Int{}))); + + Tensor tC_gAux_vec = recast>(coalesce(tC_gAux(_,_,_,epi_m,epi_n))); + Tensor tC_rAux_vec = recast>(coalesce(tC_rAux)); + + auto pred_fn = [&] (auto const&... coords) { + return elem_less(tC_cAux_vec(coords...), problem_shape_mnl); + }; + + copy_if(pred_fn, tC_gAux_vec, tC_rAux_vec); + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + return recast>(tC_rAux)(epi_v); + } + }; + + template < + bool ReferenceSrc, + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + + auto problem_shape_mnl = make_shape(M,N,L); + + // Gmem Tensor + Tensor mAux = make_tensor( + make_gmem_ptr(params_ptr->ptr_aux), make_shape(M,N,L), params_ptr->dAux + ); + Tensor tC_gAux = sm90_partition_for_epilogue( + mAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + // Register Tensor + Tensor tC_rAux = make_tensor(take<0,3>(shape(tC_gAux))); + + // Predication support + Tensor coordAux = make_identity_tensor(shape(mAux)); + Tensor tC_cAux = sm90_partition_for_epilogue( + coordAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + return ConsumerStoreCallbacks( + cute::move(tC_gAux), + cute::move(tC_rAux), + cute::move(tC_cAux), + problem_shape_mnl, + params_ptr + ); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Broadcast Load Operations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Scalar broadcast +// Supports reduction over multiple broadcasts to support fusions such as fp8 scaling factors +template< + class Element, + class StrideMNL_ = Stride<_0,_0,_0>, + int BroadcastCount = 1, + template class ReductionFn = multiplies +> +struct Sm90ScalarBroadcast { + using StrideMNL = StrideMNL_; + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_0>{}); + + struct SharedStorage { }; + + struct Arguments { + Element scalars[BroadcastCount] = {}; + Element const* scalar_ptrs[BroadcastCount] = {}; + StrideMNL dScalar[BroadcastCount] = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter *cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + // This must be called after update_scalar is called + CUTLASS_DEVICE bool + is_zero() const { + if (get<2>(params_ptr->dScalar[0]) == 0) { + // Only 1 batch + return scalar == Element(0); + } + else { + // multiple batch + if (valid_scalar == false) { + // for stridedBatch kernel, if ptr has a valid address, we need to enable the epi_load warps. + return params_ptr->scalar_ptrs[0] == nullptr; + } + else { + // Check whether each batch is ZERO or not. 
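// Illustrative sketch (editor's addition, not part of this patch): a host-side
// analogue of update_scalar() in the scalar-broadcast nodes. Each of the
// BroadcastCount scalars comes either from its pointer (offset by the batch
// index times its L-stride) or from the immediate value, and the results are
// folded with ReductionFn (multiplication by default, e.g. FP8 scale factors).
// All names below are hypothetical.
#include <cstdint>
#include <functional>

template <class Element, int BroadcastCount, class ReductionFn = std::multiplies<Element>>
Element fold_broadcast_scalars(Element const* const (&scalar_ptrs)[BroadcastCount],
                               Element const (&scalars)[BroadcastCount],
                               std::int64_t const (&stride_L)[BroadcastCount],
                               int l_coord) {
  auto pick = [&](int i) {
    // nullptr fallback: the batch stride is ignored and the immediate value is used
    return scalar_ptrs[i] ? scalar_ptrs[i][l_coord * stride_L[i]] : scalars[i];
  };
  Element scalar = pick(0);
  ReductionFn reduce{};
  for (int i = 1; i < BroadcastCount; ++i) {
    scalar = reduce(scalar, pick(i));
  }
  return scalar;
}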
+ return scalar == Element(0); + } + } + } + + CUTLASS_HOST_DEVICE + Sm90ScalarBroadcast() { } + + CUTLASS_HOST_DEVICE + Sm90ScalarBroadcast(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { + // Get the scalar for non-batched broadcast + if (size<2>(params_ptr->dScalar[0]) == 0) { + update_scalar(); + } + } + + Element scalar; + bool valid_scalar = false; + Params const* params_ptr; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + // Get the scalar for batched broadcast + if (size<2>(params_ptr->dScalar[0]) != 0) { + auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; + update_scalar(l_coord); + } + + return EmptyProducerLoadCallbacks{}; + } + + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(Element scalar) + : scalar(scalar) {} + + Element scalar; + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_scalar; + frg_scalar.fill(scalar); + + return frg_scalar; + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + // Get the scalar for batched broadcast + if (get<2>(params_ptr->dScalar[0]) != 0) { + auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; + update_scalar(l_coord); + } + + return ConsumerStoreCallbacks(scalar); + } + +private: + CUTLASS_DEVICE void + update_scalar(int l_coord = 0) { + valid_scalar = true; + int l_offset = l_coord * size<2>(params_ptr->dScalar[0]); + + if (params_ptr->scalar_ptrs[0] != nullptr) { + scalar = params_ptr->scalar_ptrs[0][l_offset]; + } + else { + // batch stride is ignored for nullptr fallback + scalar = params_ptr->scalars[0]; + } + + // Do reduction over multiple broadcasts if necessary + ReductionFn reduction_fn; + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < BroadcastCount; ++i) { + if (params_ptr->scalar_ptrs[i] != nullptr) { + int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]); + scalar = reduction_fn(scalar, params_ptr->scalar_ptrs[i][rest_l_offset]); + } + else { + // batch stride is ignored for nullptr fallback + scalar = reduction_fn(scalar, params_ptr->scalars[i]); + } + } + } + + template + CUTLASS_DEVICE void + update_scalar(cute::tuple) { + // Only support multiple L-modes with fully-broadcast scalar + scalar = params_ptr->scalars[0]; + valid_scalar = true; + } +}; + +// Scalar broadcast +// Supports reduction over multiple broadcasts to support fusions such as fp8 scaling factors +template< + class Element, + class StrideMNL_ = Stride<_0,_0,_0>, + int BroadcastCount = 1, + template class ReductionFn = multiplies +> +struct Sm90ScalarBroadcastPtrArray { + using StrideMNL = StrideMNL_; + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_0>{}); + + struct SharedStorage { }; + + struct Arguments { + Element scalars[BroadcastCount] = {}; + Element const* scalar_ptrs[BroadcastCount] = {}; + Element const* const* scalar_ptr_arrays[BroadcastCount] = {}; + StrideMNL dScalar[BroadcastCount] = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& 
problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter *cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + // producer load is needed if Element is not void + return !cute::is_void_v; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + // This must be called after update_scalar is called + CUTLASS_DEVICE bool + is_zero() const { + return scalar == Element(0); + } + + CUTLASS_HOST_DEVICE + Sm90ScalarBroadcastPtrArray() { } + + CUTLASS_HOST_DEVICE + Sm90ScalarBroadcastPtrArray(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { + // Get the scalar for non-batched broadcast + if (size<2>(params_ptr->dScalar[0]) == 0) { + update_scalar(); + } + } + + Element scalar; + Params const* params_ptr; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + // Get the scalar for batched broadcast + if (size<2>(params_ptr->dScalar[0]) != 0) { + auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; + update_scalar(l_coord); + } + + return EmptyProducerLoadCallbacks{}; + } + + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(Element scalar) + : scalar(scalar) {} + + Element scalar; + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_scalar; + frg_scalar.fill(scalar); + + return frg_scalar; + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + // Get the scalar for batched broadcast + if (get<2>(params_ptr->dScalar[0]) != 0) { + auto [m_coord, n_coord, k_coord, l_coord] = args.tile_coord_mnkl; + update_scalar(l_coord); + } + + return ConsumerStoreCallbacks(scalar); + } + +private: + CUTLASS_DEVICE void + update_scalar(int l_coord = 0) { + int l_offset = l_coord * size<2>(params_ptr->dScalar[0]); + + if (params_ptr->scalar_ptr_arrays[0] != nullptr) { + scalar = *(params_ptr->scalar_ptr_arrays[0][l_offset]); + } + else if (params_ptr->scalar_ptrs[0] != nullptr) { + scalar = params_ptr->scalar_ptrs[0][l_offset]; + } + else { + // batch stride is ignored for nullptr fallback + scalar = params_ptr->scalars[0]; + } + + // Do reduction over multiple broadcasts if necessary + ReductionFn reduction_fn; + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < BroadcastCount; ++i) { + + if (params_ptr->scalar_ptr_arrays[i] != nullptr) { + int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]); + scalar = reduction_fn(scalar, *(params_ptr->scalar_ptr_arrays[i][rest_l_offset])); + } + if (params_ptr->scalar_ptrs[i] != nullptr) { + int rest_l_offset = l_coord * size<2>(params_ptr->dScalar[i]); + scalar = reduction_fn(scalar, params_ptr->scalar_ptrs[i][rest_l_offset]); + } + else { + // batch stride is ignored for nullptr fallback + scalar = reduction_fn(scalar, params_ptr->scalars[i]); + } + } + } +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +template +[[deprecated("row broadcast only uses 0 stages")]] constexpr int +compute_row_broadcast_stages() { + return ceil_div(StagesC, size<1>(zipped_divide(make_layout(take<0,2>(CtaTileShapeMNK{})), EpilogueTile{}))) + 1; +} + +} + +// Row vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class ElementInput, + class ElementCompute = ElementInput, + class StrideMNL_ = Stride<_0,_1,_0>, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true // Fallback scalar broadcast for nullptr params +> +struct Sm90RowBroadcast { + using StrideMNL = StrideMNL_; + static_assert(Stages == 0, "Row broadcast doesn't support smem pipelining"); + + static constexpr bool IsDynamicBroadcast = is_same_v(StrideMNL{}))>, bool>; // row vector or scalar broadcast + static_assert(is_static_v(StrideMNL{}))> || IsDynamicBroadcast); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{} || IsDynamicBroadcast); + + struct SharedStorage { + array_aligned(CtaTileShapeMNK{})> smem; + }; + + struct Arguments { + ElementInput const* ptr_row = nullptr; + ElementInput null_default = ElementInput(0); + StrideMNL dRow = {}; + }; + + struct Params { + ElementInput const* ptr_row = nullptr; + ElementCompute null_default = ElementCompute(0); + StrideMNL dRow = {}; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return {args.ptr_row, ElementCompute(args.null_default), args.dRow}; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + 
CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90RowBroadcast() { } + + CUTLASS_HOST_DEVICE + Sm90RowBroadcast(Params const& params, SharedStorage const& shared_storage) + : params(params), is_zero_(false), + smem(const_cast(shared_storage.smem.data())) { + auto const& [stride_M, stride_N, stride_L] = params.dRow; + // Nullptr default + if (EnableNullptr && params.ptr_row == nullptr) { + is_zero_ = params.null_default == ElementCompute(0); + } + // Dynamic non-batched scalar broadcast + else if (IsDynamicBroadcast && stride_N == bool(0) && stride_L == repeat_like(stride_L, 0)) { + is_zero_ = params.ptr_row[0] == ElementInput(0); + } + } + + Params params; + bool is_zero_ = false; + ElementInput *smem = nullptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return is_zero_; + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + GS_GTensor tGS_gRow_, GS_STensor tGS_sRow_, + GS_CTensor tGS_cRow_, Tiled_G2S tiled_g2s_, + SR_STensor tSR_sRow_, SR_RTensor tSR_rRow_, + Residue residue_cRow_, ThrNum thr_num_, Params const& params_) + : tGS_gRow(tGS_gRow_) + , tGS_sRow(tGS_sRow_) + , tGS_cRow(tGS_cRow_) + , tiled_G2S(tiled_g2s_) + , tSR_sRow(tSR_sRow_) + , tSR_rRow(tSR_rRow_) + , residue_cRow(residue_cRow_) + , params(params_) + , is_nullptr(EnableNullptr && params_.ptr_row == nullptr) { + if (is_nullptr) { + fill(tSR_rRow, params.null_default); + } + } + + GS_GTensor tGS_gRow; // (CPY,CPY_M,CPY_N) + GS_STensor tGS_sRow; // (CPY,CPY_M,CPY_N) + GS_CTensor tGS_cRow; // (CPY,CPY_M,CPY_N) + Tiled_G2S tiled_G2S; + + SR_STensor tSR_sRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + SR_RTensor tSR_rRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + Residue residue_cRow; // (m, n) + ThrNum thr_num; + Params const& params; + bool is_nullptr; + + CUTLASS_DEVICE void + begin() { + if (is_nullptr) { + return; + } + + auto synchronize = [&] () { cutlass::arch::NamedBarrier::sync(thr_num, cutlass::arch::ReservedNamedBarriers::EpilogueBarrier); }; + Tensor tGS_gRow_flt = filter_zeros(tGS_gRow); + Tensor tGS_sRow_flt = filter_zeros(tGS_sRow); + Tensor tGS_cRow_flt = filter_zeros(tGS_cRow, tGS_gRow.stride()); + + for (int i = 0; i < size(tGS_gRow_flt); ++i) { + if (get<1>(tGS_cRow_flt(i)) >= size<1>(CtaTileShapeMNK{})) { + continue; // OOB of SMEM, + } + if (elem_less(tGS_cRow_flt(i), residue_cRow)) { + tGS_sRow_flt(i) = tGS_gRow_flt(i); + } + else { + tGS_sRow_flt(i) = ElementInput(0); // Set to Zero when OOB so LDS can be issued without any preds. 
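// Illustrative sketch (editor's addition, not part of this patch): the row
// broadcast's begin() stages the bias row from gmem into smem with residue
// predication, writing zero for out-of-bounds columns so that the later
// smem-to-register loads need no predicates, as the comment above notes.
// A simplified host-side analogue with hypothetical names:
#include <algorithm>
#include <vector>

template <class ElementInput>
void stage_row_tile(ElementInput const* gmem_row,  // start of the (N,) bias row
                    int n_offset, int n_problem,   // CTA column offset and problem N
                    int cta_n,                     // CTA_N tile width
                    std::vector<ElementInput>& smem_row) {
  smem_row.assign(cta_n, ElementInput(0));                        // OOB columns read back as zero
  int valid = std::max(0, std::min(cta_n, n_problem - n_offset)); // residue predication
  for (int i = 0; i < valid; ++i) {
    smem_row[i] = gmem_row[n_offset + i];
  }
}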
+ } + } + synchronize(); + } + + CUTLASS_DEVICE void + begin_loop(int epi_m, int epi_n) { + if (epi_m == 0 and not is_nullptr) { // Assumes M-major subtile loop + Tensor tSR_sRow_flt = filter_zeros(tSR_sRow(_,_,_,epi_m,epi_n)); + Tensor tSR_rRow_flt = make_tensor_like(tSR_sRow_flt); + copy_aligned(tSR_sRow_flt, tSR_rRow_flt); + + constexpr int FrgSize = size(tSR_rRow_flt); + using FrgInput = Array; + using FrgCompute = Array; + using ConvertInput = NumericArrayConverter; + + Tensor tSR_rRow_input_frg = recast(coalesce(tSR_rRow_flt)); + Tensor tSR_rRow_compute_frg = recast(filter(tSR_rRow)); + ConvertInput convert_input{}; + + tSR_rRow_compute_frg(_0{}) = convert_input(tSR_rRow_input_frg(_0{})); + } + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_row; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_row[i] = tSR_rRow(epi_v * FragmentSize + i); + } + + return frg_row; + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + using ThreadCount = decltype(size(args.tiled_copy)); + + auto layout_N = [&] () { + auto shape_N = get<1>(args.problem_shape_mnkl); + if constexpr (IsDynamicBroadcast) { + auto stride_N = repeat_like(shape_N, int(0)); + if (get<1>(params.dRow) == bool(1)) { + stride_N = transform_leaf(compact_major(shape_N), + [] (auto const& stride) { return static_cast(stride); } + ); + } + return make_layout(shape_N, stride_N); + } + else { + return make_layout(shape_N); + } + }(); + + auto layout_M = make_layout(M, repeat_like(M, _0{})); + auto layout_L = make_layout(L, get<2>(params.dRow)); + Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_layout(layout_M,layout_N,layout_L)); + Tensor gRow = local_tile(mRow(_,_,l), take<0,2>(args.tile_shape_mnk), make_coord(m, n)); // (CTA_M, CTA_N) + Tensor sRow = make_tensor(make_smem_ptr(smem), + make_shape(size<0>(CtaTileShapeMNK{}), size<1>(CtaTileShapeMNK{})), make_shape(_0{}, _1{})); // (CTA_M, CTA_N) + //// G2S: Gmem to Smem + auto tiled_g2s = make_tiled_copy(Copy_Atom{}, + Layout< Shape<_1, ThreadCount>, + Stride<_0, _1>>{}, + Layout<_1>{}); + auto thr_g2s = tiled_g2s.get_slice(args.thread_idx); + Tensor tGS_gRow = thr_g2s.partition_S(gRow); + Tensor tGS_sRow = thr_g2s.partition_D(sRow); + + //// G2S: Coord + Tensor tGS_cRow = thr_g2s.partition_S(args.cD); + + //// S2R: Smem to Reg + Tensor tSR_sRow = sm90_partition_for_epilogue(sRow, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tSR_rRow = make_tensor_like(take<0,3>(tSR_sRow)); // (CPY,CPY_M,CPY_N) + + return ConsumerStoreCallbacks( + tGS_gRow, + tGS_sRow, + tGS_cRow, tiled_g2s, + tSR_sRow, + tSR_rRow, + args.residue_cD, + ThreadCount{}, + params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Column vector broadcast +template< + int Stages, + class CtaTileShapeMNK, + class ElementInput, + class ElementCompute = ElementInput, + class StrideMNL_ = Stride<_1,_0,_0>, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true // Fallback scalar broadcast for nullptr params +> +struct Sm90ColBroadcast { + using StrideMNL = StrideMNL_; + static_assert(Stages == 0, "Column broadcast doesn't support smem pipelining"); + + static constexpr bool 
IsDynamicBroadcast = is_same_v(StrideMNL{}))>, bool>; // Column vector or scalar broadcast + static_assert(is_static_v(StrideMNL{}))> || IsDynamicBroadcast); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_1,_0>{} || IsDynamicBroadcast); + + // Accumulator distributes col elements evenly amongst threads so we can just directly load from gmem + struct SharedStorage { }; + + struct Arguments { + ElementInput const* ptr_col = nullptr; + ElementInput null_default = ElementInput(0); + StrideMNL dCol = {}; + }; + + struct Params { + ElementInput const* ptr_col = nullptr; + ElementCompute null_default = ElementCompute(0); + StrideMNL dCol = {}; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return {args.ptr_col, ElementCompute(args.null_default), args.dCol}; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return is_zero_; + } + + CUTLASS_HOST_DEVICE + Sm90ColBroadcast() { } + + CUTLASS_HOST_DEVICE + Sm90ColBroadcast(Params const& params, SharedStorage const& shared_storage) + : params(params), is_zero_(false) { + auto const& [stride_M, stride_N, stride_L] = params.dCol; + // Nullptr default + if (EnableNullptr && params.ptr_col == nullptr) { + is_zero_ = params.null_default == ElementCompute(0); + } + // Dynamic non-batched scalar broadcast + else if (IsDynamicBroadcast && stride_M == bool(0) && stride_L == repeat_like(stride_L, 0)) { + is_zero_ = params.ptr_col[0] == ElementInput(0); + } + } + + Params params; + bool is_zero_; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(GTensor tCgCol_, RTensor tCrCol_, CTensor tCcCol_, ThrResidue residue_tCcCol_, Params const& params_) + : tCgCol(tCgCol_), + tCrCol(tCrCol_), + tCcCol(tCcCol_), + residue_tCcCol(residue_tCcCol_), + params(params_) { + if (EnableNullptr && params.ptr_col == nullptr) { + fill(tCrCol, params.null_default); + } + } + + GTensor tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + RTensor tCrCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + CTensor tCcCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ThrResidue residue_tCcCol; + Params const& params; + + CUTLASS_DEVICE void + begin() { + if (EnableNullptr && params.ptr_col == nullptr) { + return; + } + + // Filter so we don't issue redundant copies over stride-0 modes + // (only works if 0-strides are in same location, which is by construction) + Tensor tCgCol_flt = filter_zeros(tCgCol); + Tensor tCrCol_flt = make_tensor_like(filter_zeros(tCrCol)); + Tensor tCcCol_flt = filter_zeros(tCcCol, tCgCol.stride()); + + constexpr auto MCL = decltype(max_common_layout(tCgCol_flt, tCrCol_flt)){}; + constexpr int V = cute::min(Alignment, 
size(MCL)); + if constexpr (V > 1) { + using VecType = uint_bit_t>; + Tensor tCgCol_vec = recast(coalesce(tCgCol_flt)); + Tensor tCrCol_vec = recast(coalesce(tCrCol_flt)); + Tensor tCcCol_vec = tensor<1>(zipped_divide(tCcCol_flt, MCL.compose(Int{}))); + auto pred_fn = [&] (auto const&... coords) { return elem_less(tCcCol_vec(coords...), residue_tCcCol); }; + copy_if(pred_fn, tCgCol_vec, tCrCol_vec); + } + else { + auto pred_fn = [&] (auto const&... coords) { return elem_less(tCcCol_flt(coords...), residue_tCcCol); }; + copy_if(pred_fn, tCgCol_flt, tCrCol_flt); + } + + constexpr int FrgSize = size(tCrCol_flt); + using FrgInput = Array; + using FrgCompute = Array; + using ConvertInput = NumericArrayConverter; + + Tensor tCrCol_input_frg = recast(coalesce(tCrCol_flt)); + Tensor tCrCol_compute_frg = recast(filter(tCrCol)); + ConvertInput convert_input{}; + + tCrCol_compute_frg(_0{}) = convert_input(tCrCol_input_frg(_0{})); + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_col; + Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_col[i] = tCrCol_mn(epi_v * FragmentSize + i); + } + + return frg_col; + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto layout_M = [&] () { + auto shape_M = get<0>(args.problem_shape_mnkl); + if constexpr (IsDynamicBroadcast) { + auto stride_M = repeat_like(shape_M, int(0)); + if (get<0>(params.dCol) == bool(1)) { + stride_M = transform_leaf(compact_major(shape_M), + [] (auto const& stride) { return static_cast(stride); } + ); + } + return make_layout(shape_M, stride_M); + } + else { + return make_layout(shape_M); + } + }(); + + auto layout_N = make_layout(N, repeat_like(N, _0{})); + auto layout_L = make_layout(L, get<2>(params.dCol)); + Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_layout(layout_M,layout_N,layout_L)); + Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + Tensor mCol_static = make_tensor(make_gmem_ptr(params.ptr_col), make_layout(make_layout(M),layout_N,layout_L)); + Tensor tCgCol_static = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + mCol_static, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrCol = make_tensor_like(tCgCol_static); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + return ConsumerStoreCallbacks(tCgCol, tCrCol, args.tCcD, args.residue_tCcD, params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Do outer product from the column and row loaded +// +template< + int Stages, + class CtaTileShapeMNK, + class ElementScalar, + class StrideColMNL_ = Stride<_1,_0,int64_t>, /// NOTE: Batched scaling untested for now + class StrideRowMNL_ = Stride<_0,_1,int64_t>, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = false // Fallback scalar broadcast for nullptr params +> +struct Sm90OuterProduct { + using StrideColMNL = StrideColMNL_; + using StrideRowMNL = StrideRowMNL_; + static_assert(Stages == 0, "OuterProduct doesn't support smem usage"); + static_assert(Alignment * sizeof_bits_v % 128 == 0, 
"sub-16B alignment not supported yet"); + static_assert(!EnableNullptr, "Nullptr fallback not implemented"); + static_assert(is_static_v(StrideColMNL{}))> && + is_static_v(StrideRowMNL{}))>, "Only batch stride can be dynamic"); + static_assert(take<0,2>(StrideColMNL{}) == Stride<_1,_0>{} && + take<0,2>(StrideRowMNL{}) == Stride<_0,_1>{}, "Row and column incorrectly formatted"); + + // Accumulator distributes col/row elements evenly amongst threads so we can just directly load from gmem + struct SharedStorage { }; + + struct Arguments { + ElementScalar const* ptr_col = nullptr; + ElementScalar const* ptr_row = nullptr; + StrideColMNL dCol = {}; + StrideRowMNL dRow = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_zero() const { + return false; + } + + CUTLASS_HOST_DEVICE + Sm90OuterProduct() { } + + CUTLASS_HOST_DEVICE + Sm90OuterProduct(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template< + class GTensorCol, class RTensorCol, + class GTensorRow, class RTensorRow + > + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(GTensorCol&& tCgCol, RTensorCol&& tCrCol, + GTensorRow&& tCgRow, RTensorRow&& tCrRow, + Params const& params) + : tCgCol(cute::forward(tCgCol)) + , tCrCol(cute::forward(tCrCol)) + , tCgRow(cute::forward(tCgRow)) + , tCrRow(cute::forward(tCrRow)) + , params(params) {} + + GTensorCol tCgCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + RTensorCol tCrCol; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + GTensorRow tCgRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + RTensorRow tCrRow; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + Params const& params; + + CUTLASS_DEVICE void + begin() { + + // Filter so we don't issue redundant copies over stride-0 modes + copy(filter(tCgCol), filter(tCrCol)); + copy(filter(tCgRow), filter(tCrRow)); + } + + template + CUTLASS_DEVICE Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_colrow; + Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + Tensor tCrRow_mn = tCrRow(_,_,_,epi_m,epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + frg_colrow[i] = static_cast(tCrCol_mn(epi_v * FragmentSize + i) * tCrRow_mn(epi_v * FragmentSize + i)); + } + return frg_colrow; + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); + Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); + Tensor tCgCol = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + mCol, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCgRow = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + mRow, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + Tensor tCrRow = make_tensor_like(tCgRow); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + return ConsumerStoreCallbacks< + decltype(tCgCol), decltype(tCrCol), + decltype(tCgRow), decltype(tCrRow) + >( + cute::move(tCgCol), cute::move(tCrCol), + cute::move(tCgRow), cute::move(tCrRow), + params + ); + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Batch matrix broadcast +// Only need to redefine this if we can multicast across cluster L +template < + int Stages, + class EpilogueTile, + class Element, + class StrideMNL, + class SmemLayoutAtom, + class CopyOpS2R, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true // Fallback scalar broadcast for nullptr params +> +using Sm90MatrixBroadcast + = Sm90AuxLoad; + +namespace detail { + +template +struct IsScalarBroadcast { + static constexpr bool value = false; +}; + +template +struct IsScalarBroadcast(typename Operation::StrideMNL{})), Stride<_0,_0>>>> { + static constexpr bool value = true; +}; + +} + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp new file mode 100644 index 0000000000..83cfc030df --- /dev/null +++ b/include/cutlass/epilogue/fusion/sm90_visitor_store_tma_warpspecialized.hpp @@ -0,0 +1,1724 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. 
+ * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Visitor tree store operations for the sm90 TMA warp-specialized (ws) epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/workspace.h" + +#include "cute/tensor.hpp" +#include "sm90_visitor_tma_warpspecialized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using namespace detail; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Elementwise Store Operations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + int Stages, + class EpilogueTile, + class Element, + FloatRoundStyle RoundStyle, + class StrideMNL, + class SmemLayoutAtom, + class CopyOpR2S, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true // Noop on nullptr params +> +struct Sm90AuxStore { + using ElementAux = Element; + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + + constexpr static bool is_m_major = epilogue::collective::detail::is_m_major(); + // Find the max contiguous layout usable by TMA (if EpilogueTile is a non-compact tiler) + using SmemShapeTma = decltype(make_shape( + max_common_vector(make_layout(get<0>(EpilogueTile{})),make_layout(get<0>(EpilogueTile{}))), + max_common_vector(make_layout(get<1>(EpilogueTile{})),make_layout(get<1>(EpilogueTile{}))))); + using SmemLayoutTma = decltype(tile_to_shape( + SmemLayoutAtom{}, SmemShapeTma{}, + cute::conditional_t, Step<_1,_2>>{} )); + using SmemLayout = decltype(tile_to_shape( + SmemLayoutTma{}, + make_shape(size<0>(shape(EpilogueTile{})), size<1>(shape(EpilogueTile{})), Int{}), + cute::conditional_t, Step<_1,_2,_3>>{} )); + + struct SharedStorage { + alignas(cutlass::detail::alignment_for_swizzle(SmemLayout{})) + array_aligned smem_aux; + }; + + struct Arguments { + Element* ptr_aux = nullptr; + StrideMNL dAux = {}; + }; + + struct Params { + using TMA_Aux = decltype(make_tma_copy( + SM90_TMA_STORE{}, + make_tensor(static_cast(nullptr), repeat_like(StrideMNL{}, int32_t(0)), StrideMNL{}), + SmemLayoutTma{})); + TMA_Aux tma_store_aux; + bool is_nullptr = false; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + // Optionally append 1s until problem shape is rank-4 in case its is only rank-3 (MNK) + auto problem_shape_mnkl = 
append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + + bool is_nullptr = false; + if constexpr (EnableNullptr) { + is_nullptr = args.ptr_aux == nullptr; + } + + typename Params::TMA_Aux tma_store_aux; + if (not is_nullptr) { + Tensor tensor_aux = make_tensor(args.ptr_aux, make_layout(make_shape(M,N,L), args.dAux)); + tma_store_aux = make_tma_copy(SM90_TMA_STORE{}, tensor_aux, SmemLayoutTma{}); + } + + return {tma_store_aux, is_nullptr}; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90AuxStore() { } + + CUTLASS_HOST_DEVICE + Sm90AuxStore(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms), + smem_aux(const_cast(shared_storage.smem_aux.data())) { } + + Params const* params_ptr; + Element* smem_aux; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template < + class RTensor, + class TiledR2S, + class STensorR2S, + class STensorS2G, + class GTensorS2G + > + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + RTensor&& tC_rAux, + TiledR2S tiled_r2s, + STensorR2S&& tRS_sAux, + STensorS2G&& bSG_sAux, + GTensorS2G&& bSG_gAux, + Params const* params_ptr) + : tiled_r2s(tiled_r2s), + tC_rAux(cute::forward(tC_rAux)), + tRS_sAux(cute::forward(tRS_sAux)), + bSG_sAux(cute::forward(bSG_sAux)), + bSG_gAux(cute::forward(bSG_gAux)), + params_ptr(params_ptr) {} + + TiledR2S tiled_r2s; + RTensor tC_rAux; // (CPY,CPY_M,CPY_N) + STensorR2S tRS_sAux; // (R2S,R2S_M,R2S_N,PIPE) + STensorS2G bSG_sAux; // (S2G,S2G_M,S2G_N,PIPE) + GTensorS2G bSG_gAux; // (S2G,S2G_M,S2G_N,EPI_M,EPI_N) + Params const* params_ptr; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + using ConvertInput = NumericArrayConverter; + ConvertInput convert_input{}; + + Tensor tC_rAux_frg = recast>(coalesce(tC_rAux)); // (EPI_V) + tC_rAux_frg(epi_v) = convert_input(frg_input); + + return frg_input; + } + + CUTLASS_DEVICE void + postreduce(int epi_m, int epi_n, int store_iteration, bool issue_smem_store) { + if constexpr (EnableNullptr) { + if (params_ptr->is_nullptr) { + return; + } + } + + using RLayoutR2S = decltype(cute::layout(TiledR2S{}.get_slice(0).retile_S(RTensor{}))); + Tensor tRS_rAux = make_tensor(tC_rAux.data(), RLayoutR2S{}); // (R2S,R2S_M,R2S_N) + + if (issue_smem_store) { + int store_pipe_index = store_iteration % Stages; + copy(tiled_r2s, tRS_rAux, tRS_sAux(_,_,_,store_pipe_index)); + } + } + + CUTLASS_DEVICE void + tma_store(int epi_m, int epi_n, int store_iteration, bool issue_tma_store) { + if constexpr (EnableNullptr) { + if (params_ptr->is_nullptr) { + return; + } + } + + if (issue_tma_store) { + // Issue the TMA store + int store_pipe_index = store_iteration % Stages; + copy(params_ptr->tma_store_aux, bSG_sAux(_,_,_,store_pipe_index), 
bSG_gAux(_,_,_,epi_m,epi_n)); + } + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + Tensor mAux = params_ptr->tma_store_aux.get_tma_tensor(make_shape(M,N,L)); // (M,N,L) + Tensor gAux = local_tile(mAux, take<0,2>(args.tile_shape_mnk), make_coord(m,n,l)); // (CTA_M,CTA_N) + + Tensor tC_gAux = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + gAux, args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tC_rAux = make_tensor(take<0,3>(shape(tC_gAux))); // (CPY,CPY_M,CPY_N) + + Tensor sAux_epi = cute::as_position_independent_swizzle_tensor( + make_tensor(make_smem_ptr(smem_aux), SmemLayout{})); // (EPI_TILE_M,EPI_TILE_N,PIPE) + Tensor gAux_epi = flat_divide(gAux, args.epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N) + + auto tiled_r2s = conditional_return( + make_tiled_copy_S(Copy_Atom{}, args.tiled_copy), + make_tiled_copy_D(Copy_Atom{}, args.tiled_copy) + ); + auto tRS_sAux = tiled_r2s.get_slice(args.thread_idx).partition_D(sAux_epi); // (R2S,R2S_M,R2S_N,PIPE) + + ThrCopy thrblk_s2g = params_ptr->tma_store_aux.get_slice(_0{}); + Tensor bSG_sAux = thrblk_s2g.partition_S(sAux_epi); // (TMA,TMA_M,TMA_N,PIPE) + Tensor bSG_gAux = thrblk_s2g.partition_D(gAux_epi); // (TMA,TMA_M,TMA_N,EPI_M,EPI_N) + + return ConsumerStoreCallbacks( + cute::move(tC_rAux), + tiled_r2s, + cute::move(tRS_sAux), + cute::move(bSG_sAux), + cute::move(bSG_gAux), + params_ptr); + } +}; + +template < + class Element, + class EpilogueTile, // Unused + FloatRoundStyle RoundStyle, + class LayoutOrStrideMNL, + class SmemLayoutAtom, // Unused + class CopyOpR2S, // Unused + int Alignment, + bool EnableNullptr +> +struct Sm90AuxStore< + 0, EpilogueTile, Element, RoundStyle, LayoutOrStrideMNL, + SmemLayoutAtom, CopyOpR2S, Alignment, EnableNullptr +> { + using ElementAux = Element; + using StrideMNL = cutlass::gemm::TagToStrideC_t; + + struct SharedStorage { }; + + struct Arguments { + Element* ptr_aux = nullptr; + StrideMNL dAux = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return cutlass::Status::kSuccess; + } + + CUTLASS_HOST_DEVICE + Sm90AuxStore() { } + + CUTLASS_HOST_DEVICE + Sm90AuxStore(Params const& params, SharedStorage const& shared_storage) + : params_ptr(¶ms) { } + + Params const* params_ptr; + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template< + class GTensorR2G, + class RTensor, + class CTensorR2G, + class ProblemShapeMNL + > + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + 
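// In this register-to-gmem specialization (Stages == 0, no TMA), the callbacks
// work in two phases: visit() converts each input fragment to Element and
// stashes it in the register tensor tC_rAux, and end_loop() writes the
// registers back to gmem with a vectorized, residue-predicated copy_if once
// the subtile is complete, skipping the store entirely when ptr_aux is null.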
CUTLASS_DEVICE + ConsumerStoreCallbacks( + GTensorR2G&& tC_gAux, + RTensor&& tC_rAux, + CTensorR2G&& tC_cAux, + ProblemShapeMNL problem_shape_mnl, + Params const* params_ptr) + : tC_gAux(cute::forward(tC_gAux)), + tC_rAux(cute::forward(tC_rAux)), + tC_cAux(cute::forward(tC_cAux)), + problem_shape_mnl(problem_shape_mnl), + params_ptr(params_ptr) {} + + GTensorR2G tC_gAux; + RTensor tC_rAux; + CTensorR2G tC_cAux; + ProblemShapeMNL problem_shape_mnl; + Params const* params_ptr; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + using ConvertInput = NumericArrayConverter; + ConvertInput convert_input{}; + + Tensor tC_rAux_frg = recast>(coalesce(tC_rAux)); + tC_rAux_frg(epi_v) = convert_input(frg_input); + + return frg_input; + } + + CUTLASS_DEVICE void + end_loop(int epi_m, int epi_n) { + if constexpr (EnableNullptr) { + if (params_ptr->ptr_aux == nullptr) { + return; + } + } + + constexpr auto MCL = decltype(max_common_layout(tC_gAux(_,_,_,_0{},_0{}), tC_rAux)){}; + constexpr int V = cute::min(Alignment, size(MCL)); + + Tensor tC_cAux_mn = tC_cAux(_,_,_,epi_m,epi_n); + Tensor tC_cAux_vec = tensor<1>(zipped_divide(coalesce(tC_cAux_mn), MCL.compose(Int{}))); + + Tensor tC_gAux_vec = recast>(coalesce(tC_gAux(_,_,_,epi_m,epi_n))); + Tensor tC_rAux_vec = recast>(coalesce(tC_rAux)); + + auto pred_fn = [&] (auto const&... coords) { + return elem_less(tC_cAux_vec(coords...), problem_shape_mnl); + }; + + copy_if(pred_fn, tC_rAux_vec, tC_gAux_vec); + } + }; + + template < + bool ReferenceSrc, + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + + auto problem_shape_mnl = make_shape(M,N,L); + + // Gmem Tensor + Tensor mAux = make_tensor( + make_gmem_ptr(params_ptr->ptr_aux), make_shape(M,N,L), params_ptr->dAux + ); + Tensor tC_gAux = sm90_partition_for_epilogue( + mAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + // Register Tensor + Tensor tC_rAux = make_tensor(take<0,3>(shape(tC_gAux))); + + // Predication support + Tensor coordAux = make_identity_tensor(shape(mAux)); + Tensor tC_cAux = sm90_partition_for_epilogue( + coordAux, args.tile_shape_mnk, args.tile_coord_mnkl, args.epi_tile, args.tiled_copy, args.thread_idx); + + return ConsumerStoreCallbacks( + cute::move(tC_gAux), + cute::move(tC_rAux), + cute::move(tC_cAux), + problem_shape_mnl, + params_ptr + ); + + } + +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Reduction Store Operations +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Scalar reduction +template < + template class RegReduceFn, + template class GmemReduceFn, + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle, + class StrideMNL = Stride<_0,_0,_0>, + bool EnableNullptr = true // Noop on nullptr params +> +struct Sm90ScalarReduction { +private: + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_0>{}); + static constexpr bool IsAtomic = is_atomic>::value; + static_assert(IsAtomic, "non-atomic scalar reduction not supported yet"); + +public: + struct SharedStorage { }; + + struct Arguments { + ElementOutput* ptr_scalar = nullptr; + ElementCompute reduction_identity = 
ElementCompute(0); + StrideMNL dScalar = {}; + }; + + using Params = Arguments; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return args; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + #if !defined(CUTLASS_SKIP_REDUCTION_INIT) + if constexpr (IsAtomic) { + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + Layout mScalar_layout = make_layout(make_shape(M,N,L), args.dScalar); + if (args.ptr_scalar != nullptr) { + return fill_workspace(args.ptr_scalar, ElementOutput(args.reduction_identity), cosize(mScalar_layout), stream, cuda_adapter); + } + } + #endif + + return cutlass::Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_HOST_DEVICE + Sm90ScalarReduction() { } + + CUTLASS_HOST_DEVICE + Sm90ScalarReduction(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params const params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks( + int l_coord, + CTensor tCcScalar, + ThrResidue residue_tCcScalar, + Params const& params) + : scalar(params.reduction_identity), + l_coord(l_coord), + tCcScalar(tCcScalar), + residue_tCcScalar(residue_tCcScalar), + params(params) {} + + ElementCompute scalar; + int l_coord; + CTensor tCcScalar; // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + ThrResidue residue_tCcScalar; + Params params; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + if constexpr (EnableNullptr) { + if (params.ptr_scalar == nullptr) { + return frg_input; + } + } + + using ConvertInput = NumericArrayConverter; + using ReduceInput = RegReduceFn; + ConvertInput convert_input{}; + ReduceInput reduce_input{}; + + Array frg_I = convert_input(frg_input); + Tensor tCcScalar_mn = tCcScalar(_,_,_,epi_m,epi_n); + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + if (elem_less(tCcScalar_mn(epi_v * FragmentSize + i), residue_tCcScalar)) { + scalar = reduce_input(scalar, frg_I[i]); + } + } + + return frg_input; + } + + CUTLASS_DEVICE void + end() { + if constexpr (EnableNullptr) { + if (params.ptr_scalar == nullptr) { + return; + } + } + + using ConvertI = NumericConverter; + using ReduceInput = GmemReduceFn; + + ConvertI convert_I{}; + ReduceInput reduce_input{}; + + ElementOutput* ptr_scalar = params.ptr_scalar + l_coord * get<2>(params.dScalar); + reduce_input(ptr_scalar, convert_I(scalar)); + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + return ConsumerStoreCallbacks( + get<3>(args.tile_coord_mnkl), args.tCcD, args.residue_tCcD, params); + } + +}; + + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Row vector reduction +template < + template class RegReduceFn, + template class ShuffleReduceFn, + template class GmemReduceFn, + int Stages, + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle, + class StrideMNL = Stride<_0,_1,_0>, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true, // Noop on nullptr params + // If this is false, ptr_row is assumed to point to a compact n-major (ceil_div(M,CTA_M), round_nearest(N,CTA_N), L) + // tensor of ElementCompute. It is the user's responsibility to reduce this to a (N, L) tensor of ElementOutput + bool FinalReduction = true, + // False means skip OOB predication if OOB inputs are known to be the reduction identity + bool VisitCheckOOB = true, + // Indicate the parameter order when calling RegReduceFn + // Seq length equals the number of RegReduceFn parameters + // No.0 represents tCrRow; No.1 and subsequent numbers sequentially represent frg_inputs in `visit` + class RegReduceSeq = cute::seq<0, 1> +> +struct Sm90RowReduction { +private: + static_assert(Stages == 0, "Smem usage not supported yet"); + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_0,_1>{}); + static constexpr bool IsAtomic = is_atomic>::value; + static_assert(not (IsAtomic && not FinalReduction), "atomic reduction must be final"); + +public: + struct SharedStorage { }; + + struct Arguments { + void* ptr_row = nullptr; // ElementOutput* if FinalReduction, else ElementCompute* + ElementCompute reduction_identity = 0; + StrideMNL dRow = {}; + }; + + struct Params { + void* ptr_row = nullptr; + ElementCompute reduction_identity = 0; + StrideMNL dRow = {}; + ElementCompute* reduction_buffer = nullptr; + int* tile_counters = nullptr; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + ElementCompute* reduction_buffer; + int* tile_counters = nullptr; + if constexpr (IsAtomic) { + reduction_buffer = nullptr; + } + else if constexpr (FinalReduction) { + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; + size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M), size<>(N), L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute); + tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); + + reduction_buffer = reinterpret_cast(workspace); + tile_counters = reinterpret_cast(reinterpret_cast(workspace) + tile_counters_offset); + } + else { + reduction_buffer = reinterpret_cast(args.ptr_row); + } + + return { + args.ptr_row, + args.reduction_identity, + args.dRow, + reduction_buffer, + tile_counters + }; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + if constexpr (IsAtomic || not FinalReduction) { + return 0; + } + 
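+    // Descriptive note (added): workspace layout is a partial-reduction buffer of ElementCompute
+    // (one tile_N-wide row per (m-tile, n-tile, l) block), padded up to MinWorkspaceAlignment,
+    // followed by one int tile counter per n-tile, used below to elect the CTA that performs
+    // the final cross-tile reduction.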
+ size_t workspace_size = 0; + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; + // Increment by size of reduction buffer + workspace_size += product(ceil_div(make_shape(size<>(M),size<>(N),L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute); + // Align and increment by size of tile counters + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + workspace_size += cute::ceil_div(size<>(N), tile_N) * sizeof(int); + return workspace_size; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + if constexpr (IsAtomic) { + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + Layout mRow_layout = make_layout(make_shape(size<>(M),size<>(N),size<>(L)), args.dRow); + if (args.ptr_row != nullptr) { + return fill_workspace(args.ptr_row, ElementOutput(args.reduction_identity), cosize(mRow_layout), stream, cuda_adapter); + } + return Status::kSuccess; + } + else if constexpr (FinalReduction) { + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; + size_t tile_counters_offset = product(ceil_div(make_shape(size<>(M),size<>(N),L), make_shape(tile_M, tile_N))) * tile_N * sizeof(ElementCompute); + tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); + + int* tile_counters = reinterpret_cast(reinterpret_cast(workspace) + tile_counters_offset); + size_t tile_counters_size = cute::ceil_div(size<>(N), tile_N) * sizeof(int); + return zero_workspace(tile_counters, tile_counters_size, stream, cuda_adapter); + } + else { + return Status::kSuccess; + } + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_HOST_DEVICE + Sm90RowReduction() { } + + CUTLASS_HOST_DEVICE + Sm90RowReduction(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(ArgsTuple&& args_tuple, Params const& params) + : args_tuple(cute::forward(args_tuple)), + params(params) {} + + ArgsTuple args_tuple; + Params const& params; + bool do_final_reduction = false; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const&... 
frg_inputs) { + if constexpr (EnableNullptr) { + if (params.ptr_row == nullptr) { + return cute::get<0>(cute::make_tuple(frg_inputs...)); + } + } + + auto& [ref_src, tCrRow, tCcRow, gRow_l, cRow, gBuf_ml, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + tile_coord_mnkl, residue_cRow, residue_tCcRow, epi_tile, tiled_copy, thread_idx] = args_tuple; + Tensor tCrRow_mn = tCrRow(_,_,_,epi_m,epi_n); + Tensor tCcRow_mn = tCcRow(_,_,_,epi_m,epi_n); + + if constexpr (VisitCheckOOB) { + using ReduceInput = RegReduceFn; + ReduceInput reduce_input{}; + + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + if (elem_less(tCcRow_mn(epi_v * FragmentSize + i), residue_tCcRow)) { + ElementCompute& tCrRow_vmn = tCrRow_mn(epi_v * FragmentSize + i); + tCrRow_vmn = transform_apply(cute::make_tuple(frg_inputs...), + [&] (auto&& frg_input) { + return ElementCompute(frg_input[i]); + }, + [&] (auto&&... cvt_frg_inputs) { + auto frg_compute_tuple = cute::make_tuple(tCrRow_vmn, cvt_frg_inputs...); + return cute::detail::apply(frg_compute_tuple, reduce_input, RegReduceSeq{}); + }); + } + } + } + else { + constexpr int RegFragSize = cute::max(1, static_cast(sizeof(uint32_t) / sizeof(ElementCompute))); + using ReduceInput = RegReduceFn>; + ReduceInput reduce_input{}; + Tensor tCrRow_mn_frg = recast>(tCrRow_mn); + + constexpr int RegFragArraySize = FragmentSize / RegFragSize; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < RegFragArraySize; ++i) { + Array& tCrRow_vmn_frg = tCrRow_mn_frg(epi_v * RegFragArraySize + i); + tCrRow_vmn_frg = transform_apply(cute::make_tuple(frg_inputs...), + [&] (auto&& frg_input) { + using ElementInput = typename cute::remove_cvref_t::Element; + using ConvertInput = NumericArrayConverter; + using RegFragArr = Array, RegFragArraySize>; + ConvertInput convert_input{}; + return convert_input(reinterpret_cast(frg_input)[i]); + }, + [&] (auto&&... cvt_frg_inputs) { + auto frg_compute_tuple = cute::make_tuple(tCrRow_vmn_frg, cvt_frg_inputs...); + return cute::detail::apply(frg_compute_tuple, reduce_input, RegReduceSeq{}); + }); + } + } + return cute::get<0>(cute::make_tuple(frg_inputs...)); + } + + template + CUTLASS_DEVICE void + reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { + if (not is_last_iteration) { + return; + } + + auto& [ref_src, tCrRow, tCcRow, gRow_l, cRow, gBuf_ml, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + tile_coord_mnkl, residue_cRow, residue_tCcRow, epi_tile, tiled_copy, thread_idx] = args_tuple; + auto [m, n, k, l] = tile_coord_mnkl; + constexpr bool ReferenceSrc = decltype(ref_src)::value; + if constexpr (EnableNullptr) { + if (params.ptr_row == nullptr) { + return; + } + } + + // fully OOB CTA in partially OOB cluster + if (not elem_less(cRow(_0{},_0{}), residue_cRow)) { + return; + } + + int lane_m = get<0>(lane_mn); + [[maybe_unused]] bool is_reduced_lane = lane_m == 0; + + // + // 1. Warp shuffle reduction + // + using FragmentShuffle = Array; + Tensor tCrRow_frg = recast(filter(tCrRow)); + using ReduceShuffle = ShuffleReduceFn; + ReduceShuffle reduce_shuffle{}; + + auto FrgSizePerLaneM = size(tCrRow_frg) / size<0>(lane_layout_MN); + constexpr bool SwapShuffle = FrgSizePerLaneM > 0; + + // + // Swap Shuffle + // + // The normal way to reduction among threads: + // use shuffle to let *** the first half of threads *** have *** whole data *** from the second half of threads. 
+ // After each step of reduction, a half of threads won't work in the following steps. + // That is, as the reduction progresses, the efficiency of shuffle & reduction instructions gradually change from 1/2, 1/4 to 1/32 (the worst case). + // + // To overcome this shortcoming, for a NxN matrix to be reduced among N threads as a 1XN vectors, + // we use swap & shuffle aiming to let *** each half of threads *** have *** a half of data *** from the other half of threads. + // After reduction, each half of threads should deal with a (N/2)x(N/2) sub-matrix independently in the following step. + // We can recursively do this until the problem size is 1. + // + if constexpr (SwapShuffle) { // for a NxN matrix to be reduced among N threads as a 1XN vectors + Tensor tCrRow_frg_ = logical_divide(tCrRow_frg, FrgSizePerLaneM); // (FrgSizePerLaneM, M) + CUTLASS_PRAGMA_UNROLL + for (int m = size<1>(tCrRow_frg_) / 2; m > 0; m /= 2) { + CUTLASS_PRAGMA_UNROLL + for (int r = 0; r < m; ++r) { + auto frg_A = tCrRow_frg_(_,r); + auto frg_B = tCrRow_frg_(_,r + m); + CUTLASS_PRAGMA_UNROLL + for (int v = 0; v < size(frg_A); ++v) { + // Step1: swap + if (not (lane_m & m)) { // the first half of threads swap fragments from the first half of data to the second + cutlass::swap(frg_A(v), frg_B(v)); + } + + // Step2: shuffle + uint64_t frg_shfl = reinterpret_cast(frg_A(v)); + // each half of threads get a half of data from the other half of threads + frg_shfl = __shfl_xor_sync(0xFFFFFFFF, frg_shfl, lane_layout_MN(m, _0{})); + + // Step3: reduction + frg_A(v) = reduce_shuffle(frg_B(v), reinterpret_cast(frg_shfl)); + } + } + } + } + else { + CUTLASS_PRAGMA_UNROLL + for (int reduction_rows = size<0>(lane_layout_MN) / 2; reduction_rows > 0; reduction_rows /= 2) { + CUTLASS_PRAGMA_UNROLL + for (int frg_idx = 0; frg_idx < size(tCrRow_frg); ++frg_idx) { + uint64_t frg_shfl = reinterpret_cast(tCrRow_frg(frg_idx)); + frg_shfl = __shfl_down_sync(0xFFFFFFFF, frg_shfl, lane_layout_MN(reduction_rows, _0{})); + tCrRow_frg(frg_idx) = reduce_shuffle(tCrRow_frg(frg_idx), reinterpret_cast(frg_shfl)); + } + } + } + + // + // 2. Atomic reduction + // + if constexpr (IsAtomic) { + // Filter so we don't issue redunant copies over stride-0 modes + Tensor tCrRow_flt = filter_zeros(tCrRow); + Tensor tCcRow_flt = make_tensor(tCcRow.data(), make_layout(tCrRow_flt.shape(), tCcRow.stride())); + auto FltFrgSizePerLaneM = size(tCrRow_flt) / size<0>(lane_layout_MN); + + Tensor tCgRow = sm90_partition_for_epilogue(gRow_l(_,_,l), epi_tile, tiled_copy, thread_idx); + Tensor tCgRow_flt = filter_zeros(tCgRow); + // NOTE: atomic reduction is performed in the output type + using ConvertOutput = NumericConverter; + using ReduceOutput = GmemReduceFn; + ConvertOutput convert_output{}; + ReduceOutput reduce_output{}; + + if constexpr (SwapShuffle) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FltFrgSizePerLaneM; ++i) { + int idx = lane_m * FltFrgSizePerLaneM + i; + // Only care about OOB for N mode + if (get<1>(tCcRow_flt(idx)) < get<1>(residue_tCcRow)) { + reduce_output(&tCgRow_flt(idx), convert_output(tCrRow_flt(i))); + } + } + } + else { + if (is_reduced_lane) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrRow_flt); ++i) { + if (elem_less(tCcRow_flt(i), residue_tCcRow)) { + reduce_output(&tCgRow_flt(i), convert_output(tCrRow_flt(i))); + } + } + } + } + sync_fn(); + } + + // + // 2. 
One warp in M, skip threadblock smem reduction + // + else if constexpr (decltype(size<0>(warp_layout_MN))::value <= 1) { + // Dump warp reduction to gmem workspace + using ElementGmem = cute::conditional_t; + Tensor tCgBuf = sm90_partition_for_epilogue(gBuf_ml(_,_,m,l), epi_tile, tiled_copy, thread_idx); + + if constexpr (SwapShuffle) { + Tensor tCrRow_flt = filter(tCrRow); + Tensor tCgBuf_flt = recast(filter(tCgBuf)); + auto FltFrgSizePerLaneM = size(tCrRow_flt) / size<0>(lane_layout_MN); + Tensor tCgBuf_flt_ = logical_divide(tCgBuf_flt, FltFrgSizePerLaneM); // (FltFrgSizePerLaneM, M) + Tensor tCrRow_flt_ = logical_divide(tCrRow_flt, FltFrgSizePerLaneM); // (FltFrgSizePerLaneM, M) + copy_aligned(tCrRow_flt_(_,_0{}), tCgBuf_flt_(_,lane_m)); + } + else { + if (is_reduced_lane) { + copy_aligned(tCrRow, recast(tCgBuf)); + } + } + sync_fn(); + } + + // + // 2. Multiple warps in M, do threadblock smem reduction + // + else { + Tensor sBuf = make_tensor(make_smem_ptr(raw_pointer_cast(smem_buffer.data())), sBuf_layout); + static_assert(decltype(cosize(sBuf.layout()))::value * sizeof(ElementCompute) <= + decltype(cosize(smem_buffer.layout()))::value * sizeof(typename remove_cvref_t::value_type), + "smem reduction buffer not large enough, use a larger epilogue tile"); + sync_fn(); + + // Dump warp reduction to smem workspace + Tensor tCsBuf = sm90_partition_for_epilogue(sBuf(_,_,get<0>(warp_mn)), epi_tile, tiled_copy, thread_idx); + + if constexpr (SwapShuffle) { + Tensor tCrRow_flt = filter(tCrRow); + Tensor tCsBuf_flt = filter(tCsBuf); + auto FltFrgSizePerLaneM = size(tCrRow_flt) / size<0>(lane_layout_MN); + Tensor tCsBuf_flt_ = logical_divide(tCsBuf_flt, FltFrgSizePerLaneM); // (FltFrgSizePerLaneM, M) + Tensor tCrRow_flt_ = logical_divide(tCrRow_flt, FltFrgSizePerLaneM); // (FltFrgSizePerLaneM, M) + copy_aligned(tCrRow_flt_(_,_0{}), tCsBuf_flt_(_,lane_m)); + } + else { + if (is_reduced_lane) { + copy_aligned(tCrRow, tCsBuf); + } + } + sync_fn(); + + constexpr int SmemFragSize = cute::max(size_t{1}, sizeof(uint32_t) / sizeof(ElementCompute)); + using FragmentSmem = Array; + using VectorSmem = uint_bit_t>; + using ReduceSmem = GmemReduceFn; + ReduceSmem reduce_smem{}; + + Tensor sBuf_frg = recast(filter_zeros(sBuf)); + Tensor sBuf_vec = recast(filter_zeros(sBuf)); + constexpr int FragsPerRow = decltype(size<1>(sBuf_frg))::value; + + constexpr int RowNum = decltype(size<0>(warp_layout_MN))::value; + using FragmentSmemArray = Array; + + // Do the threadblock smem reduction + using VectorGmem = cute::conditional_t; + Tensor gBuf_vec = recast(filter(gBuf_ml(_,_,m,l))); + CUTLASS_PRAGMA_UNROLL + for (int frg_idx = thread_idx; frg_idx < FragsPerRow; frg_idx += size(tiled_copy)) { + FragmentSmemArray frg_smem; + + CUTLASS_PRAGMA_UNROLL + for (int reduction_rows = 0; reduction_rows < RowNum; ++reduction_rows) { + int FragsCurrRows = reduction_rows * FragsPerRow; + frg_smem[reduction_rows] = sBuf_frg(FragsCurrRows + frg_idx); + } + + CUTLASS_PRAGMA_UNROLL + for (int reduction_rows = RowNum / 2; reduction_rows > 0; reduction_rows /= 2) { + CUTLASS_PRAGMA_UNROLL + for (int row_idx = 0; row_idx < reduction_rows; ++row_idx) { + frg_smem[row_idx] = reduce_smem(frg_smem[row_idx], frg_smem[row_idx + reduction_rows]); + } + } + gBuf_vec(frg_idx) = reinterpret_cast(frg_smem[0]); + } + sync_fn(); + } + + // + // 3. 
Increment atomic counters to signal final gmem reduction + // + if constexpr (not IsAtomic && FinalReduction) { + // Ensure gmem writes are visible to other threads before incrementing counter + __threadfence(); + sync_fn(); + // Collective thread 0 increments atomic tile counter and copies value to smem + int* prev_tile_count = reinterpret_cast(raw_pointer_cast(smem_buffer.data())); + if (thread_idx == 0) { + *prev_tile_count = atomicAdd(¶ms.tile_counters[n], 1); + } + sync_fn(); + // Broadcast tile count to other threads in CTA and determine final reduction status + do_final_reduction = *prev_tile_count == size<2>(gBuf_ml) * size<3>(gBuf_ml) - 1; + sync_fn(); + } + } + + CUTLASS_DEVICE void + end() { + // + // 4. Do final gmem reduction if necessary + // + if constexpr (not IsAtomic && FinalReduction) { + if (not do_final_reduction) { + return; + } + + auto& [ref_src, tCrRow, tCcRow, gRow_l, cRow, gBuf_ml, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + tile_coord_mnkl, residue_cRow, residue_tCcRow, epi_tile, tiled_copy, thread_idx] = args_tuple; + + using ReduceOutput = GmemReduceFn; + using ConvertOutput = NumericConverter; + ReduceOutput reduce_output{}; + ConvertOutput convert_output{}; + + // Reduction over batches + if (size<2>(stride(gRow_l)) == 0) { + CUTLASS_PRAGMA_NO_UNROLL + for (int n = thread_idx; n < size<1>(gBuf_ml); n += size(tiled_copy)) { + Tensor tRgBuf_ml = gBuf_ml(_0{},n,_,_); + ElementCompute output = tRgBuf_ml(_0{}); + CUTLASS_PRAGMA_NO_UNROLL + for (int ml = 1; ml < size(tRgBuf_ml); ++ml) { + output = reduce_output(output, tRgBuf_ml(ml)); + } + if (elem_less(cRow(_0{},n), residue_cRow)) { + gRow_l(_0{},n,_0{}) = convert_output(output); + } + } + } + // No reduction over batches + else { + CUTLASS_PRAGMA_NO_UNROLL + for (int n = thread_idx; n < size<1>(gBuf_ml); n += size(tiled_copy)) { + bool do_store = elem_less(cRow(_0{},n), residue_cRow); + CUTLASS_PRAGMA_NO_UNROLL + for (int l = 0; l < size<3>(gBuf_ml); ++l) { + Tensor tRgBuf_m = gBuf_ml(_0{},n,_,l); + ElementCompute output = tRgBuf_m(_0{}); + CUTLASS_PRAGMA_NO_UNROLL + for (int m = 1; m < size(tRgBuf_m); ++m) { + output = reduce_output(output, tRgBuf_m(m)); + } + if (do_store) { + gRow_l(_0{},n,l) = convert_output(output); + } + } + } + } + + } + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + Layout ref_layout_MN = [&] () { + if constexpr (ReferenceSrc) { return get<0>(args.tiled_copy.get_layoutS_MN()); } + else { return get<0>(args.tiled_copy.get_layoutD_MN()); } + }(); // tile_mn -> tv_idx + + // Get the MN layout + coord of lanes to determine shuffle reduction iterations + using _W = Int; + Layout tv2lane = Layout,_W,_1>,Stride<_1,_0,_0>>{}; // tv_idx -> lane_idx + Layout ref2lane = composition(tv2lane, ref_layout_MN); // tile_mn -> lane_idx + Layout lane_layout_MN = make_layout(filter(get<0>(ref2lane)), filter(get<1>(ref2lane))); // lane_mn -> lane_idx + Layout inv_lane_layout_MN = right_inverse(lane_layout_MN); // lane_idx -> lane_mn + int lane_idx = canonical_lane_idx(); + auto lane_mn = idx2crd(inv_lane_layout_MN(lane_idx), shape(lane_layout_MN)); + + // Get the MN layout + coord of warps to determine smem reduction iterations + Layout tv2warp = Layout,_W,_1>,Stride<_0,_1,_0>>{}; // tv_idx -> warp_idx + Layout ref2warp = composition(tv2warp, ref_layout_MN); // tile_mn -> warp_idx + Layout warp_layout_MN = make_layout(filter(get<0>(ref2warp)), filter(get<1>(ref2warp))); // warp_mn -> warp_idx + Layout inv_warp_layout_MN = right_inverse(warp_layout_MN); // warp_idx -> warp_mn + + int warp_idx = args.thread_idx / NumThreadsPerWarp; + auto warp_mn = idx2crd(inv_warp_layout_MN(warp_idx), shape(warp_layout_MN)); + + // Partition output gmem and register tensors + auto [tile_M, tile_N, tile_K] = args.tile_shape_mnk; + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + + Tensor mRow = make_tensor(make_gmem_ptr(params.ptr_row), make_shape(M,N,L), params.dRow); // (M,N,L) + Tensor gRow_l = local_tile(mRow, take<0,2>(args.tile_shape_mnk), make_coord(m,n,_)); // (CTA_M,CTA_N,L) + Tensor tCgRow = sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + gRow_l(_,_,l), args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrRow = make_tensor_like(tCgRow); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + fill(tCrRow, params.reduction_identity); + + // Partition gmem+smem reduction buffer tensors + Layout gBuf_layout = make_layout(take<0,2>(args.tile_shape_mnk), make_stride(_0{}, _1{})); + auto block_shape = ceil_div(make_shape(M,N,L), shape(gBuf_layout)); // (M_CNT, N_CNT, L_CNT) + + // Let the M_CNT (the num of partial reduction results) become the outer mode + Layout block_layout = make_layout(block_shape, make_stride(get<1>(block_shape), _1{}, get<0>(block_shape) * get<1>(block_shape))); + Layout mBuf_layout = blocked_product(gBuf_layout, block_layout); + Tensor mBuf = make_tensor(make_gmem_ptr(params.reduction_buffer), mBuf_layout); // (ceil_M,ceil_N,L) + Tensor gBuf_ml = local_tile(mBuf, take<0,2>(args.tile_shape_mnk), make_coord(_,n,_)); // (CTA_M,CTA_N,REST_M,L) + Layout sBuf_layout = blocked_product(gBuf_layout, // (CTA_M,CTA_N,WARPS_M) + make_layout(make_shape(_1{},_1{},size<0>(warp_layout_MN)))); + + auto args_tuple = make_tuple( + bool_constant{}, cute::move(tCrRow), args.tCcD, gRow_l, args.cD, gBuf_ml, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + args.tile_coord_mnkl, args.residue_cD, args.residue_tCcD, args.epi_tile, args.tiled_copy, args.thread_idx); + return ConsumerStoreCallbacks(cute::move(args_tuple), params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Col vector reduction +template < + template class RegReduceFn, + template class 
ShuffleReduceFn, + template class GmemReduceFn, + int Stages, + class CtaTileShapeMNK, + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle, + class StrideMNL = Stride<_1,_0,_0>, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true, // Noop on nullptr params + // If this is false, ptr_col is assumed to point to a compact m-major (round_nearest(M,CTA_M), ceil_div(N,CTA_N), L) + // tensor of ElementCompute. It is the user's responsibility to reduce this to a (M, L) tensor of ElementOutput + bool FinalReduction = true, + // False means skip OOB predication if OOB inputs are known to be the reduction identity + bool VisitCheckOOB = true +> +struct Sm90ColReduction { +private: + static_assert(Stages == 0, "Smem usage not supported yet"); + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + static_assert(is_static_v(StrideMNL{}))>); // batch stride can be dynamic or static + static_assert(take<0,2>(StrideMNL{}) == Stride<_1,_0>{}); + static constexpr bool IsAtomic = is_atomic>::value; + static_assert(not (IsAtomic && not FinalReduction), "atomic reduction must be final"); + +public: + struct SharedStorage { }; + + struct Arguments { + void* ptr_col = nullptr; // ElementOutput* if FinalReduction, else ElementCompute* + ElementCompute reduction_identity = 0; + StrideMNL dCol = {}; + }; + + struct Params { + void* ptr_col = nullptr; + ElementCompute reduction_identity = 0; + StrideMNL dCol = {}; + ElementCompute* reduction_buffer = nullptr; + int* tile_counters = nullptr; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + ElementCompute* reduction_buffer; + int* tile_counters = nullptr; + if constexpr (IsAtomic) { + reduction_buffer = nullptr; + } + else if constexpr (FinalReduction) { + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; + size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute); + tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); + + reduction_buffer = reinterpret_cast(workspace); + tile_counters = reinterpret_cast(reinterpret_cast(workspace) + tile_counters_offset); + } + else { + reduction_buffer = reinterpret_cast(args.ptr_col); + } + + return { + args.ptr_col, + args.reduction_identity, + args.dCol, + reduction_buffer, + tile_counters + }; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return true; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + if constexpr (IsAtomic || not FinalReduction) { + return 0; + } + + size_t workspace_size = 0; + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; + + // Increment by size of reduction buffer + workspace_size += product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute); + // Align and increment by size of tile counters + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + workspace_size += cute::ceil_div(M, tile_M) * sizeof(int); + + return workspace_size; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& 
args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + if constexpr (IsAtomic) { + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + Layout mCol_layout = make_layout(make_shape(size<>(M),size<>(N),size<>(L)), args.dCol); + if (args.ptr_col != nullptr) { + return fill_workspace(args.ptr_col, ElementOutput(args.reduction_identity), cosize(mCol_layout), stream, cuda_adapter); + } + return Status::kSuccess; + } + else if constexpr (FinalReduction) { + auto problem_shape_mnkl = append<4>(problem_shape, 1); + auto [M, N, K, L] = problem_shape_mnkl; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; + size_t tile_counters_offset = product(ceil_div(make_shape(M,N,L), make_shape(tile_M, tile_N))) * tile_M * sizeof(ElementCompute); + tile_counters_offset = round_nearest(tile_counters_offset, MinWorkspaceAlignment); + + int* tile_counters = reinterpret_cast(reinterpret_cast(workspace) + tile_counters_offset); + size_t tile_counters_size = cute::ceil_div(M, tile_M) * sizeof(int); + return zero_workspace(tile_counters, tile_counters_size, stream, cuda_adapter); + } + else { + return Status::kSuccess; + } + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_HOST_DEVICE + Sm90ColReduction() { } + + CUTLASS_HOST_DEVICE + Sm90ColReduction(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(ArgsTuple&& args_tuple, Params const& params) + : args_tuple(cute::forward(args_tuple)), + params(params) {} + + ArgsTuple args_tuple; + Params const& params; + bool do_final_reduction = false; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + if constexpr (EnableNullptr) { + if (params.ptr_col == nullptr) { + return frg_input; + } + } + + auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + tile_coord_mnkl, residue_cCol, residue_tCcCol, epi_tile, tiled_copy, thread_idx] = args_tuple; + Tensor tCrCol_mn = tCrCol(_,_,_,epi_m,epi_n); + Tensor tCcCol_mn = tCcCol(_,_,_,epi_m,epi_n); + + using ConvertInput = NumericArrayConverter; + using ReduceInput = RegReduceFn; + ConvertInput convert_input{}; + ReduceInput reduce_input{}; + + Array frg_I = convert_input(frg_input); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + if (!VisitCheckOOB || elem_less(tCcCol_mn(epi_v * FragmentSize + i), residue_tCcCol)) { + ElementCompute& tCrCol_vmn = tCrCol_mn(epi_v * FragmentSize + i); + tCrCol_vmn = reduce_input(tCrCol_vmn, frg_I[i]); + } + } + + return frg_input; + } + + template + CUTLASS_DEVICE void + reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { + if (not is_last_iteration) { + return; + } + + auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + tile_coord_mnkl, residue_cCol, residue_tCcCol, epi_tile, tiled_copy, thread_idx] = args_tuple; + auto [m, n, k, l] = tile_coord_mnkl; + constexpr bool ReferenceSrc = 
decltype(ref_src)::value; + + // Runtime nullptr is noop + if constexpr (EnableNullptr) { + if (params.ptr_col == nullptr) { + return; + } + } + + // fully OOB CTA in partially OOB cluster + if (not elem_less(cCol(_0{},_0{}), residue_cCol)) { + return; + } + + // + // 1. Warp shuffle reduction + // + using FragmentShuffle = Array; + using ReduceShuffle = ShuffleReduceFn; + ReduceShuffle reduce_shuffle{}; + Tensor tCrCol_frg = recast(filter(tCrCol)); + CUTLASS_PRAGMA_UNROLL + for (int reduction_cols = size<1>(lane_layout_MN) / 2; reduction_cols > 0; reduction_cols /= 2) { + CUTLASS_PRAGMA_UNROLL + for (int frg_idx = 0; frg_idx < size(tCrCol_frg); ++frg_idx) { + uint64_t frg_shfl = reinterpret_cast(tCrCol_frg(frg_idx)); + frg_shfl = __shfl_down_sync(0xFFFFFFFF, frg_shfl, lane_layout_MN(_0{},reduction_cols)); + tCrCol_frg(frg_idx) = reduce_shuffle(tCrCol_frg(frg_idx), reinterpret_cast(frg_shfl)); + } + } + bool is_reduced_lane = get<1>(lane_mn) == 0; + + // + // 2. Atomic reduction + // + if constexpr (IsAtomic) { + // Filter so we don't issue redunant copies over stride-0 modes + Tensor tCrCol_flt = filter_zeros(tCrCol); + Tensor tCcCol_flt = make_tensor(tCcCol.data(), make_layout(tCrCol_flt.shape(), tCcCol.stride())); + + Tensor tCgCol = sm90_partition_for_epilogue(gCol_l(_,_,l), epi_tile, tiled_copy, thread_idx); + Tensor tCgCol_flt = filter_zeros(tCgCol); + + // NOTE: atomic reduction is performed in the output type + using ConvertOutput = NumericConverter; + using ReduceOutput = GmemReduceFn; + ConvertOutput convert_output{}; + ReduceOutput reduce_output{}; + + if (is_reduced_lane) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrCol_flt); ++i) { + if (elem_less(tCcCol_flt(i), residue_tCcCol)) { + reduce_output(&tCgCol_flt(i), convert_output(tCrCol_flt(i))); + } + } + } + sync_fn(); + } + + // + // 2. One warp in N, skip threadblock smem reduction + // + else if constexpr (decltype(size<1>(warp_layout_MN))::value <= 1) { + // Dump warp reduction to gmem workspace + using ElementGmem = cute::conditional_t; + Tensor tCgBuf = sm90_partition_for_epilogue(gBuf_nl(_,_,n,l), epi_tile, tiled_copy, thread_idx); + if (is_reduced_lane) { + copy_aligned(tCrCol, recast(tCgBuf)); + } + sync_fn(); + } + + // + // 2. 
Multiple warps in N, do threadblock smem reduction + // + else { + Tensor sBuf = make_tensor(make_smem_ptr(raw_pointer_cast(smem_buffer.data())), sBuf_layout); + static_assert(decltype(cosize(sBuf.layout()))::value * sizeof(ElementCompute) <= + decltype(cosize(smem_buffer.layout()))::value * sizeof(typename remove_cvref_t::value_type), + "smem reduction buffer not large enough, use a larger epilogue tile"); + sync_fn(); + + // Dump warp reduction to smem workspace + Tensor tCsBuf = sm90_partition_for_epilogue(sBuf(_,_,get<1>(warp_mn)), epi_tile, tiled_copy, thread_idx); + if (is_reduced_lane) { + copy_aligned(tCrCol, tCsBuf); + } + sync_fn(); + + constexpr int SmemFragSize = cute::max(size_t{1}, sizeof(uint32_t) / sizeof(ElementCompute)); + using FragmentSmem = Array; + using VectorSmem = uint_bit_t>; + using ReduceSmem = GmemReduceFn; + ReduceSmem reduce_smem{}; + + Tensor sBuf_frg = recast(filter_zeros(sBuf)); + Tensor sBuf_vec = recast(filter_zeros(sBuf)); + constexpr int FragsPerCol = decltype(size<0>(sBuf_frg))::value; + + // Do the threadblock smem reduction + CUTLASS_PRAGMA_UNROLL + for (int reduction_cols = size<1>(warp_layout_MN) / 2; reduction_cols > 1; reduction_cols /= 2) { + int FragsPerReduction = reduction_cols * FragsPerCol; + CUTLASS_PRAGMA_NO_UNROLL + for (int frg_idx = thread_idx; frg_idx < FragsPerReduction; frg_idx += size(tiled_copy)) { + FragmentSmem frg_smem = reduce_smem(sBuf_frg(frg_idx), sBuf_frg(frg_idx + FragsPerReduction)); + sBuf_vec(frg_idx) = reinterpret_cast(frg_smem); + } + sync_fn(); + } + + // Do final smem reduction and dump to gmem workspace + using VectorGmem = cute::conditional_t; + Tensor gBuf_vec = recast(filter(gBuf_nl(_,_,n,l))); + CUTLASS_PRAGMA_NO_UNROLL + for (int frg_idx = thread_idx; frg_idx < FragsPerCol; frg_idx += size(tiled_copy)) { + FragmentSmem frg_smem = reduce_smem(sBuf_frg(frg_idx), sBuf_frg(frg_idx + FragsPerCol)); + gBuf_vec(frg_idx) = reinterpret_cast(frg_smem); + } + sync_fn(); + } + + // + // 3. Increment atomic counters to signal final gmem reduction + // + if constexpr (not IsAtomic && FinalReduction) { + // Ensure gmem writes are visible to other threads before incrementing counter + __threadfence(); + sync_fn(); + // Collective thread 0 increments atomic tile counter and copies value to smem + int* prev_tile_count = reinterpret_cast(raw_pointer_cast(smem_buffer.data())); + if (thread_idx == 0) { + *prev_tile_count = atomicAdd(¶ms.tile_counters[m], 1); + } + sync_fn(); + // Broadcast tile count to other threads in CTA and determine final reduction status + do_final_reduction = *prev_tile_count == size<2>(gBuf_nl) * size<3>(gBuf_nl) - 1; + sync_fn(); + } + } + + CUTLASS_DEVICE void + end() { + // + // 4. 
Do final gmem reduction if necessary + // + if constexpr (not IsAtomic && FinalReduction) { + if (not do_final_reduction) { + return; + } + + auto& [ref_src, tCrCol, tCcCol, gCol_l, cCol, gBuf_nl, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + tile_coord_mnkl, residue_cCol, residue_tCcCol, epi_tile, tiled_copy, thread_idx] = args_tuple; + + using ReduceOutput = GmemReduceFn; + using ConvertOutput = NumericConverter; + ReduceOutput reduce_output{}; + ConvertOutput convert_output{}; + + // Reduction over batches + if (size<2>(stride(gCol_l)) == 0) { + CUTLASS_PRAGMA_NO_UNROLL + for (int m = thread_idx; m < size<0>(gBuf_nl); m += size(tiled_copy)) { + Tensor tRgBuf_nl = gBuf_nl(m,_0{},_,_); + ElementCompute output = tRgBuf_nl(_0{}); + CUTLASS_PRAGMA_NO_UNROLL + for (int nl = 1; nl < size(tRgBuf_nl); ++nl) { + output = reduce_output(output, tRgBuf_nl(nl)); + } + if (elem_less(cCol(m,_0{}), residue_cCol)) { + gCol_l(m,_0{},_0{}) = convert_output(output); + } + } + } + // No reduction over batches + else { + CUTLASS_PRAGMA_NO_UNROLL + for (int m = thread_idx; m < size<0>(gBuf_nl); m += size(tiled_copy)) { + bool do_store = elem_less(cCol(m,_0{}), residue_cCol); + CUTLASS_PRAGMA_NO_UNROLL + for (int l = 0; l < size<3>(gBuf_nl); ++l) { + Tensor tRgBuf_n = gBuf_nl(m,_0{},_,l); + ElementCompute output = tRgBuf_n(_0{}); + CUTLASS_PRAGMA_NO_UNROLL + for (int n = 1; n < size(tRgBuf_n); ++n) { + output = reduce_output(output, tRgBuf_n(n)); + } + if (do_store) { + gCol_l(m,_0{},l) = convert_output(output); + } + } + } + } + + } + } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + Layout ref_layout_MN = [&] () { + if constexpr (ReferenceSrc) { return get<0>(args.tiled_copy.get_layoutS_MN()); } + else { return get<0>(args.tiled_copy.get_layoutD_MN()); } + }(); // tile_mn -> tv_idx + + // Get the MN layout + coord of lanes to determine shuffle reduction iterations + using _W = Int; + Layout tv2lane = Layout,_W,_1>,Stride<_1,_0,_0>>{}; // tv_idx -> lane_idx + Layout ref2lane = composition(tv2lane, ref_layout_MN); // tile_mn -> lane_idx + Layout lane_layout_MN = make_layout(filter(get<0>(ref2lane)), filter(get<1>(ref2lane))); // lane_mn -> lane_idx + Layout inv_lane_layout_MN = right_inverse(lane_layout_MN); // lane_idx -> lane_mn + int lane_idx = canonical_lane_idx(); + auto lane_mn = idx2crd(inv_lane_layout_MN(lane_idx), shape(lane_layout_MN)); + + // Get the MN layout + coord of warps to determine smem reduction iterations + Layout tv2warp = Layout,_W,_1>,Stride<_0,_1,_0>>{}; // tv_idx -> warp_idx + Layout ref2warp = composition(tv2warp, ref_layout_MN); // tile_mn -> warp_idx + Layout warp_layout_MN = make_layout(filter(get<0>(ref2warp)), filter(get<1>(ref2warp))); // warp_mn -> warp_idx + Layout inv_warp_layout_MN = right_inverse(warp_layout_MN); // warp_idx -> warp_mn + int warp_idx = args.thread_idx / NumThreadsPerWarp; + auto warp_mn = idx2crd(inv_warp_layout_MN(warp_idx), shape(warp_layout_MN)); + + // Partition output gmem and register tensors + auto [tile_M, tile_N, tile_K] = args.tile_shape_mnk; + auto [M, N, K, L] = args.problem_shape_mnkl; + auto [m, n, k, l] = args.tile_coord_mnkl; + + Tensor mCol = make_tensor(make_gmem_ptr(params.ptr_col), make_shape(M,N,L), params.dCol); // (M,N,L) + Tensor gCol_l = local_tile(mCol, take<0,2>(args.tile_shape_mnk), make_coord(m,n,_)); // (CTA_M,CTA_N,L) + Tensor tCgCol 
= sm90_partition_for_epilogue( // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + gCol_l(_,_,l), args.epi_tile, args.tiled_copy, args.thread_idx); + Tensor tCrCol = make_tensor_like(tCgCol); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + fill(tCrCol, params.reduction_identity); + + // Partition gmem+smem reduction buffer tensors + Layout gBuf_layout = make_layout(take<0,2>(args.tile_shape_mnk), make_stride(_1{}, _0{})); + Layout mBuf_layout = blocked_product(gBuf_layout, make_layout(ceil_div(make_shape(M,N,L), shape(gBuf_layout)))); + Tensor mBuf = make_tensor(make_gmem_ptr(params.reduction_buffer), mBuf_layout); // (ceil_M,ceil_N,L) + Tensor gBuf_nl = local_tile(mBuf, take<0,2>(args.tile_shape_mnk), make_coord(m,_,_)); // (CTA_M,CTA_N,REST_N,L) + Layout sBuf_layout = blocked_product(gBuf_layout,make_layout(make_shape(_1{},_1{},size<1>(warp_layout_MN)))); // (CTA_M,CTA_N,WARPS_N) + + auto args_tuple = make_tuple( + bool_constant{}, cute::move(tCrCol), args.tCcD, gCol_l, args.cD, gBuf_nl, sBuf_layout, + lane_layout_MN, lane_mn, warp_layout_MN, warp_mn, + args.tile_coord_mnkl, args.residue_cD, args.residue_tCcD, args.epi_tile, args.tiled_copy, args.thread_idx); + return ConsumerStoreCallbacks(std::move(args_tuple), params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Batch matrix reduction +template < + int Stages, + class EpilogueTile, + class Element, + class StrideMNL, + class CopyOpR2S, + class SmemLayoutAtom, + int Alignment = 128 / sizeof_bits_v, + bool EnableNullptr = true // Noop on nullptr params +> +struct Sm90MatrixReduction; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp new file mode 100644 index 0000000000..48f4756d1f --- /dev/null +++ b/include/cutlass/epilogue/fusion/sm90_visitor_tma_warpspecialized.hpp @@ -0,0 +1,1137 @@ +/*************************************************************************************************** + * Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. 
IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Visitor tree operation base implementation to enable composable fusions + for the sm90 TMA warp-specialized (ws) epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/workspace.h" + +#include "cute/tensor.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +using namespace cute; +using cute::tuple; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace detail { + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Partitioning Helpers +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class CtaTileMN, + class EpilogueTile, + class TiledCopy +> +CUTLASS_HOST_DEVICE +constexpr auto +sm90_partition_for_epilogue( + CtaTileMN cT, // (CTA_M,CTA_N,...) + EpilogueTile epi_tile, // (EPI_TILE_M,EPI_TILE_N) + TiledCopy tiled_copy, + int thread_idx) { + ThrCopy thread_copy = tiled_copy.get_thread_slice(thread_idx); + Tensor cT_epi = flat_divide(cT, epi_tile); // (EPI_TILE_M,EPI_TILE_N,EPI_M,EPI_N,...) + if constexpr (ReferenceSrc) { + return thread_copy.partition_S(cT_epi); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,...) + } + else { + return thread_copy.partition_D(cT_epi); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N,...) 
+ } +} + +template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class Engine, class LayoutMNL, + class TileShapeMNK, + class TileCoordMNKL, + class EpilogueTile, + class TiledCopy +> +CUTLASS_HOST_DEVICE +constexpr auto +sm90_partition_for_epilogue( + Tensor mT, // (M,N,L) + TileShapeMNK tile_shape_mnk, // (CTA_M,CTA_N,CTA_K) + TileCoordMNKL tile_coord_mnkl, // (m,n,k,l) + EpilogueTile epi_tile, // (EPI_TILE_M,EPI_TILE_N) + TiledCopy tiled_copy, + int thread_idx) { + auto [m, n, k, l] = tile_coord_mnkl; + auto coord_shape = + make_coord(m, n, l) + ; + Tensor cT = local_tile(mT, take<0,2>(tile_shape_mnk), coord_shape); // (CTA_M,CTA_N) + Tensor tCcT = + sm90_partition_for_epilogue(cT, epi_tile, tiled_copy, thread_idx); // (CPY,CPY_M,CPY_N,EPI_M,EPI_N) + + return tCcT; +} + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Visitor Implementation +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class TiledMma, + class EpilogueTile +> +struct ProducerLoadArgs { + ProblemShapeMNKL problem_shape_mnkl; + TileShapeMNK tile_shape_mnk; + TileCoordMNKL tile_coord_mnkl; + TiledMma tiled_mma; + EpilogueTile epi_tile; + int thread_idx; + + CUTLASS_DEVICE + ProducerLoadArgs( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + TiledMma tiled_mma, + EpilogueTile epi_tile, + int thread_idx) + : problem_shape_mnkl(problem_shape_mnkl), + tile_shape_mnk(tile_shape_mnk), + tile_coord_mnkl(tile_coord_mnkl), + tiled_mma(tiled_mma), + epi_tile(epi_tile), + thread_idx(thread_idx) {} +}; + +template< + class ProblemShapeMNKL, + class TileShapeMNK, + class TileCoordMNKL, + class TiledMma, + class EpilogueTile, + class TiledCopy, + class CoordTensor, + class Residue, + class ThrCoordTensor, + class ThrResidue, + class ThrSrcTensor +> +struct ConsumerStoreArgs { + ProblemShapeMNKL problem_shape_mnkl; + TileShapeMNK tile_shape_mnk; + TileCoordMNKL tile_coord_mnkl; + TiledMma tiled_mma; + EpilogueTile epi_tile; + TiledCopy tiled_copy; + CoordTensor cD; + Residue residue_cD; + ThrCoordTensor tCcD; + ThrResidue residue_tCcD; + ThrSrcTensor & tCrC; + int thread_idx; + + CUTLASS_DEVICE + ConsumerStoreArgs( + ProblemShapeMNKL problem_shape_mnkl, + TileShapeMNK tile_shape_mnk, + TileCoordMNKL tile_coord_mnkl, + TiledMma tiled_mma, + EpilogueTile epi_tile, + TiledCopy tiled_copy, + CoordTensor cD, + Residue residue_cD, + ThrCoordTensor tCcD, + ThrResidue residue_tCcD, + ThrSrcTensor & tCrC, + int thread_idx) + : problem_shape_mnkl(problem_shape_mnkl), + tile_shape_mnk(tile_shape_mnk), + tile_coord_mnkl(tile_coord_mnkl), + tiled_mma(tiled_mma), + epi_tile(epi_tile), + tiled_copy(tiled_copy), + cD(cD), + residue_cD(residue_cD), + tCcD(tCcD), + residue_tCcD(residue_tCcD), + tCrC(tCrC), + thread_idx(thread_idx) {} +}; + +template +struct Sm90VisitorImplBase { + // Shared memory allocation + using SharedStorage = tuple; + // Host side fusion arguments + using Arguments = tuple; + // Device side fusion params (Kernel-entry API) + using Params = tuple; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + uint8_t* op_workspace = reinterpret_cast(workspace); + return transform_apply(tuple{}, args, + [&] (auto&& op, auto const& op_args) { + using Op = 
cute::remove_cvref_t; + auto ret = Op::to_underlying_arguments(problem_shape, op_args, op_workspace); + if (op_workspace != nullptr) { + size_t op_workspace_size = Op::get_workspace_size(problem_shape, op_args); + op_workspace += round_nearest(op_workspace_size, MinWorkspaceAlignment); + } + return ret; + }, + [] (auto&&... op_params) { return cute::make_tuple(op_params...); } + ); + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return transform_apply(tuple{}, args, + [&] (auto&& op, auto const& op_args) { + using Op = cute::remove_cvref_t; + return Op::can_implement(problem_shape, op_args); + }, + [&] (auto&&... implementable) { + return (true && ... && implementable); + } + ); + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return transform_apply(tuple{}, args, + [&] (auto&& op, auto const& op_args) { + using Op = cute::remove_cvref_t; + size_t op_workspace_size = Op::get_workspace_size(problem_shape, op_args); + return round_nearest(op_workspace_size, MinWorkspaceAlignment); + }, + [&] (auto&&... op_workspace_size) { + return (0 + ... + op_workspace_size); + } + ); + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* op_workspace = reinterpret_cast(workspace); + return transform_apply(tuple{}, args, + // Initialize each operation's workspace, stopping at the first error + [&] (auto&& op, auto const& op_args) { + if (status != Status::kSuccess) { + return status; + } + + using Op = cute::remove_cvref_t; + status = Op::initialize_workspace(problem_shape, op_args, op_workspace, stream, cuda_adapter); + if (op_workspace != nullptr) { + size_t op_workspace_size = Op::get_workspace_size(problem_shape, op_args); + op_workspace += round_nearest(op_workspace_size, MinWorkspaceAlignment); + } + return status; + }, + // Return the final status + [&] (auto const&...ops) { return status; } + ); + } + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase() {} + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase(Params const& params, SharedStorage const& shared_storage) + : ops(transform_apply(tuple{}, params, shared_storage, + [] (auto&& op, auto const& op_params, auto&& op_storage) { + using Op = cute::remove_cvref_t; + return Op(op_params, op_storage); + }, + [] (auto&&... ops) { return cute::make_tuple(ops...); } + )) {} + + // Ops can store kernel persistent variables (e.g. descriptors, scalars, wave counters) + tuple ops; +}; + +template +struct Sm90VisitorImpl : Sm90VisitorImplBase { + + using Impl = Sm90VisitorImplBase; + using Params = typename Impl::Params; + using SharedStorage = typename Impl::SharedStorage; + + CUTLASS_HOST_DEVICE + Sm90VisitorImpl() {} + + CUTLASS_HOST_DEVICE + Sm90VisitorImpl(Params const& params, SharedStorage const& shared_storage) + : Impl(params, shared_storage) {} + + using Impl::ops; + + // + // Queries for kernel runtime + // + + // Is a specialized warp for producer TMA loads needed + // e.g. Aux tensor loads, broadcasts using TMA bulk copy + // This condition cannot change between work tiles because it is used + // to determine whether the load warp should exit early or not + // e.g. 
for batched beta this must always be true regardless of current batch idx + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return cute::apply(ops, + [] (auto const&... op) { + return (false || ... || op.is_producer_load_needed()); + } + ); + } + + // Is a producer TMA load specifically for C needed + // If this is true then is_producer_load_needed must also be true + // This condition can change between work tiles because it is only used + // to determine whether the TMA and smem loads for C of a given tile should happen + // e.g. for batched beta this can be false depending on current batch idx + CUTLASS_DEVICE bool + is_C_load_needed() const { + return cute::apply(ops, + [] (auto const&... op) { + return (false || ... || op.is_C_load_needed()); + } + ); + } + + // + // Producer load callbacks, called by the epilogue load warp. + // Operations usually only define this if TMA load is needed. Most operations will reuse this empy implementation + // Load callbacks are responsible for issuing corresponding mbarrier expect-tx ops for any TMA loads issued, but + // are not responsible for issuing the producer_commit barrier arrival, which is issued by the collective instead + // If this is non-empty, is_producer_load_needed must be true. + // + template + struct ProducerLoadCallbacks { + // Callbacks can store non-persistent variables (e.g. tensors) or copies of persistent variables + CallbacksTuple callbacks_tuple; + + // Before entry of the subtile load loop + CUTLASS_DEVICE void + begin() { + for_each(callbacks_tuple, + [&] (auto& callbacks) { + callbacks.begin(); + } + ); + } + + // Entry of the subtile load loop. Aux loads usually performed here + // Upon entry the producer acquire of the current subtile lock has completed. + // Upon exit all TMA loads for this subtile must have been issued, with corresponding expect-tx operations + CUTLASS_DEVICE void + step(uint64_t* full_mbarrier_ptr, int epi_m, int epi_n, int load_iteration, bool issue_tma_load) { + for_each(callbacks_tuple, + [&] (auto& callbacks) { + callbacks.step(full_mbarrier_ptr, epi_m, epi_n, load_iteration, issue_tma_load); + } + ); + } + + // Exit of the subtile load loop. + CUTLASS_DEVICE void + end() { + for_each(callbacks_tuple, + [] (auto& callbacks) { + callbacks.end(); + } + ); + } + }; + + // Producer load callbacks factory + // All operations must redefine this, but most can just dispatch to the base impl + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return transform_apply(ops, + [&] (auto& op) { + return op.get_producer_load_callbacks(args); + }, + [] (auto&&... callbacks) { + auto callbacks_tuple = cute::make_tuple(callbacks...); + return ProducerLoadCallbacks{callbacks_tuple}; + } + ); + } + + // + // Consumer store callbacks, called by the epilogue store warps. + // All operations must redefine this, with optional inheritance from this empty implementation. + // + template + struct ConsumerStoreCallbacks { + // Callbacks can store non-persistent variables (e.g. tensors) or copies of persistent variables + CallbacksTuple callbacks_tuple; + + // Before entry of subtile store loop. Gmem broadcasts usually performed here. 
+ CUTLASS_DEVICE void + begin() { + for_each(callbacks_tuple, + [] (auto& callbacks) { + callbacks.begin(); + } + ); + } + + // Start of subtile store iteration + CUTLASS_DEVICE void + begin_loop(int epi_m, int epi_n) { + for_each(callbacks_tuple, + [&] (auto& callbacks) { + callbacks.begin_loop(epi_m, epi_n); + } + ); + } + + // Before visit callback. Smem broadcasts usually performed here. + // Upon entry, all producer loads for this subtile are completed and visible. + CUTLASS_DEVICE void + previsit(int epi_m, int epi_n, int load_iteration, bool is_producer_load_needed) { + for_each(callbacks_tuple, + [&] (auto& callbacks) { + callbacks.previsit(epi_m, epi_n, load_iteration, is_producer_load_needed); + } + ); + } + + // Perform the fused elementwise computation + template + CUTLASS_DEVICE auto // returns an Array + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const&... frg_inputs) // depends on the N-naryness of the op + = delete; // Must be implemented for each operation + + // After visit call. Smem reductions usually performed here + // reduction_buffer is an arbitrary smem tensor that can be used for workspace + // It is each nodes reponsibility to assert that this buffer is sufficiently sized + // and to ensure that this buffer is no longer needed upon callback exit + // i.e. results are synchronized and no longer in the reduction buffer + // + // visit_results is a rmem tensor that contains the results of visit() for an entire + // on the current epilogue subtile + template + CUTLASS_DEVICE void + reduce(STensor&& reduction_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { + for_each(callbacks_tuple, + [&] (auto& callbacks) { + callbacks.reduce(reduction_buffer, sync_fn, epi_m, epi_n, is_last_iteration, visit_results); + } + ); + } + + // After reduce call, before smem async fence. Smem stores usually performed here. + // Upon exit, all smem stores for TMA must have been issued + CUTLASS_DEVICE void + postreduce(int epi_m, int epi_n, int store_iteration, bool issue_smem_store) { + for_each(callbacks_tuple, + [&] (auto& callbacks) { + callbacks.postreduce(epi_m, epi_n, store_iteration, issue_smem_store); + } + ); + } + + // After smem async fence, before TMA store commit. Aux stores usually performed here + // Upon exit, all TMA stores for this subtile must have been issued + // Because of the TMA store delay optimization, this entry point must ONLY be used for TMA stores + // other gmem stores can be placed in the reduce or postreduce entry points + CUTLASS_DEVICE void + tma_store(int epi_m, int epi_n, int store_iteration, bool issue_tma_store) { + for_each(callbacks_tuple, + [&] (auto& callbacks) { + callbacks.tma_store(epi_m, epi_n, store_iteration, issue_tma_store); + } + ); + } + + // End of subtile store iteration + CUTLASS_DEVICE void + end_loop(int epi_m, int epi_n) { + for_each(callbacks_tuple, + [&] (auto& callbacks) { + callbacks.end_loop(epi_m, epi_n); + } + ); + } + + // Exit of subtile store loop. Gmem reductions usually performed here. + CUTLASS_DEVICE void + end() { + for_each(callbacks_tuple, + [&] (auto& callbacks) { + callbacks.end(); + } + ); + } + }; + + // Consumer store callbacks factory + // All operations must redefine this + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + return transform_apply(ops, + [&] (auto& op) { + return op.template get_consumer_store_callbacks(args); + }, + [] (auto&&... callbacks) { + auto callbacks_tuple = cute::make_tuple(callbacks...); + return ConsumerStoreCallbacks{callbacks_tuple}; + } + ); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Convenience aliases +using EmptyProducerLoadCallbacks = Sm90VisitorImpl<>::ProducerLoadCallbacks>; +using EmptyConsumerStoreCallbacks = Sm90VisitorImpl<>::ConsumerStoreCallbacks>; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace detail + +using namespace detail; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// Tree visitor +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +template +struct Sm90TreeVisitor : Sm90VisitorImpl { + + using Impl = Sm90VisitorImpl; + using Params = typename Impl::Params; + using SharedStorage = typename Impl::SharedStorage; + + CUTLASS_HOST_DEVICE + Sm90TreeVisitor() {} + + CUTLASS_HOST_DEVICE + Sm90TreeVisitor( + Params const& params, + SharedStorage const& shared_storage) + : Impl(params, shared_storage) {} + + template + struct ConsumerStoreCallbacks : CallbacksImpl { + CUTLASS_DEVICE + ConsumerStoreCallbacks(CallbacksImpl&& impl) + : CallbacksImpl(cute::forward(impl)) {} + + using CallbacksImpl::callbacks_tuple; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + constexpr int Rm1 = sizeof...(ChildOps); + return cute::detail::tapply(callbacks_tuple, + [&] (auto& child_callbacks) { + return child_callbacks.visit(frg_acc, epi_v, epi_m, epi_n); // child ops must be nullary (e.g. loads, trees) + }, + [&] (auto&&... frg_inputs) { + return get(callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n, frg_inputs...); + }, + make_seq{} // restrict the transform to R-1 child ops, apply is for node op + ); + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto callbacks_tuple = Sm90VisitorImpl:: + template get_consumer_store_callbacks(args); + return ConsumerStoreCallbacks(std::move(callbacks_tuple)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// +// +// DAG visitors +// +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Most DAG fusions can be represented as a set of output trees with a common input tree +// The common input is first evaluated, then the result is passed as the acc fragment to the output trees +template +struct Sm90SplitTreeVisitor : Sm90VisitorImpl { + + using Sm90VisitorImpl::Sm90VisitorImpl; + + template + struct ConsumerStoreCallbacks : CallbacksImpl { + CUTLASS_DEVICE + ConsumerStoreCallbacks(CallbacksImpl&& impl) + : CallbacksImpl(cute::forward(impl)) {} + + using CallbacksImpl::callbacks_tuple; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + Array frg_input = get<0>(callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n); + + constexpr int Rm2 = sizeof...(AuxOutTrees); + cute::for_each(make_seq{}, // restrict the sequence to aux out trees + [&] (auto I) { + get(callbacks_tuple).visit(frg_input, epi_v, epi_m, epi_n); + } + ); + + return get(callbacks_tuple).visit(frg_input, epi_v, epi_m, epi_n); + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto callbacks_tuple = Sm90VisitorImpl:: + template get_consumer_store_callbacks(args); + return ConsumerStoreCallbacks(std::move(callbacks_tuple)); + } +}; +///////////////////////////////////////////////////////////////////////////////////////////////// + +template< + // deducing the output type for all the nodes is tricky so we just convert them all to a common type + // if multiple compute types are needed then split into multiple subgraphs grouped by type + class ElementCompute, + class EdgeTuple, // tuple of int_sequence, each sequence is the children indices (indexed by topological order) for each node + class... Ops // in topological order, last op is the output. EdgeTuple must match this order +> +struct Sm90TopologicalVisitor : Sm90VisitorImpl { + static_assert(is_static_v); + static_assert(cute::rank(EdgeTuple{}) == sizeof...(Ops)); + static_assert(sizeof...(Ops) > 1); + + using Sm90VisitorImpl::Sm90VisitorImpl; + + template + struct ConsumerStoreCallbacks : CallbacksImpl { + CUTLASS_DEVICE + ConsumerStoreCallbacks(CallbacksImpl&& impl) + : CallbacksImpl(cute::forward(impl)) {} + + using CallbacksImpl::callbacks_tuple; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n) { + constexpr int Rm1 = sizeof...(Ops) - 1; + auto frg_compute_tuple = cute::repeat(Array{}); + + return cute::detail::tapply(EdgeTuple{}, callbacks_tuple, frg_compute_tuple, + // Visit the first R-1 ops in topological order + [&] (auto&& edge_seq, auto& callbacks, auto& frg_compute) { + frg_compute = cute::detail::apply(frg_compute_tuple, + // Compute the current op with children inputs + [&] (auto const&... 
frg_inputs) { + auto frg_output = callbacks.visit(frg_acc, epi_v, epi_m, epi_n, frg_inputs...); + using ElementOutput = typename decltype(frg_output)::Element; + using ConvertOutput = NumericArrayConverter; + ConvertOutput convert_output{}; + + return convert_output(frg_output); + }, + // Get inputs in the sequence given by the children indices of the current op + edge_seq + ); + return frg_compute; // unused + }, + // Visit the last op + [&] (auto const&...ops) { + return cute::detail::apply(frg_compute_tuple, + // Compute the last op with children inputs + [&] (auto const&... frg_inputs) { + return get(callbacks_tuple).visit(frg_acc, epi_v, epi_m, epi_n, frg_inputs...); + }, + // Get inputs in the sequence given by the children indices of the last op + get(EdgeTuple{}) + ); + }, + // Transform to visit R-1 ops, apply to visit last op + make_seq{} + ); + } + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + auto callbacks_tuple = Sm90VisitorImpl:: + template get_consumer_store_callbacks(args); + return ConsumerStoreCallbacks(std::move(callbacks_tuple)); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Base specializations so we can have standard layout params and simple aggregate initializers +namespace detail { + +template +struct Sm90VisitorImplBase { + + // Retain tuple for SharedStorage because empty structs have 1B alignment + // tuples use multiple inheritance, avoids this problem + using SharedStorage = tuple< + typename Op0::SharedStorage + >; + + struct Arguments { + typename Op0::Arguments op_0; + }; + + struct Params { + typename Op0::Params op_0; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return Params{ + Op0::to_underlying_arguments(problem_shape, args.op_0, workspace) + }; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return Op0::can_implement(problem_shape, args.op_0); + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + size_t workspace_size = 0; + workspace_size += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + status = Op0::initialize_workspace(problem_shape, args.op_0, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase() {} + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase(Params const& params, SharedStorage const& shared_storage) + : ops({ + Op0(params.op_0, get<0>(shared_storage)) + }) {} + + tuple ops; +}; + +template +struct Sm90VisitorImplBase { + + using SharedStorage = tuple< + typename 
Op0::SharedStorage, + typename Op1::SharedStorage + >; + + struct Arguments { + typename Op0::Arguments op_0; + typename Op1::Arguments op_1; + }; + + struct Params { + typename Op0::Params op_0; + typename Op1::Params op_1; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + size_t op_0_workspace_size = Op0::get_workspace_size(problem_shape, args.op_0); + uint8_t* op_0_workspace = reinterpret_cast(workspace); + uint8_t* op_1_workspace = op_0_workspace + op_0_workspace_size; + return Params{ + Op0::to_underlying_arguments(problem_shape, args.op_0, op_0_workspace), + Op1::to_underlying_arguments(problem_shape, args.op_1, op_1_workspace) + }; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return Op0::can_implement(problem_shape, args.op_0) && + Op1::can_implement(problem_shape, args.op_1); + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + size_t workspace_size = 0; + workspace_size += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += Op1::get_workspace_size(problem_shape, args.op_1); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + status = Op0::initialize_workspace(problem_shape, args.op_0, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = Op1::initialize_workspace(problem_shape, args.op_1, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += Op1::get_workspace_size(problem_shape, args.op_1); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase() {} + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase(Params const& params, SharedStorage const& shared_storage) + : ops({ + Op0(params.op_0, get<0>(shared_storage)), + Op1(params.op_1, get<1>(shared_storage)) + }) {} + + tuple ops; +}; + +template +struct Sm90VisitorImplBase { + + using SharedStorage = tuple< + typename Op0::SharedStorage, + typename Op1::SharedStorage, + typename Op2::SharedStorage + >; + + struct Arguments { + typename Op0::Arguments op_0; + typename Op1::Arguments op_1; + typename Op2::Arguments op_2; + }; + + struct Params { + typename Op0::Params op_0; + typename Op1::Params op_1; + typename Op2::Params op_2; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + size_t op_0_workspace_size = Op0::get_workspace_size(problem_shape, args.op_0); + size_t op_1_workspace_size = Op1::get_workspace_size(problem_shape, args.op_1); + uint8_t* op_0_workspace = reinterpret_cast(workspace); + uint8_t* op_1_workspace = op_0_workspace + op_0_workspace_size; + 
uint8_t* op_2_workspace = op_1_workspace + op_1_workspace_size; + return Params{ + Op0::to_underlying_arguments(problem_shape, args.op_0, op_0_workspace), + Op1::to_underlying_arguments(problem_shape, args.op_1, op_1_workspace), + Op2::to_underlying_arguments(problem_shape, args.op_2, op_2_workspace) + }; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return Op0::can_implement(problem_shape, args.op_0) && + Op1::can_implement(problem_shape, args.op_1) && + Op2::can_implement(problem_shape, args.op_2); + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + size_t workspace_size = 0; + workspace_size += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += Op1::get_workspace_size(problem_shape, args.op_1); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += Op2::get_workspace_size(problem_shape, args.op_2); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + status = Op0::initialize_workspace(problem_shape, args.op_0, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = Op1::initialize_workspace(problem_shape, args.op_1, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += Op1::get_workspace_size(problem_shape, args.op_1); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = Op2::initialize_workspace(problem_shape, args.op_2, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += Op2::get_workspace_size(problem_shape, args.op_2); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase() {} + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase(Params const& params, SharedStorage const& shared_storage) + : ops({ + Op0(params.op_0, get<0>(shared_storage)), + Op1(params.op_1, get<1>(shared_storage)), + Op2(params.op_2, get<2>(shared_storage)) + }) {} + + tuple ops; +}; + +template +struct Sm90VisitorImplBase { + + using SharedStorage = tuple< + typename Op0::SharedStorage, + typename Op1::SharedStorage, + typename Op2::SharedStorage, + typename Op3::SharedStorage + >; + + struct Arguments { + typename Op0::Arguments op_0; + typename Op1::Arguments op_1; + typename Op2::Arguments op_2; + typename Op3::Arguments op_3; + }; + + struct Params { + typename Op0::Params op_0; + typename Op1::Params op_1; + typename Op2::Params op_2; + typename Op3::Params op_3; + }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + size_t op_0_workspace_size = Op0::get_workspace_size(problem_shape, 
args.op_0); + size_t op_1_workspace_size = Op1::get_workspace_size(problem_shape, args.op_1); + size_t op_2_workspace_size = Op2::get_workspace_size(problem_shape, args.op_2); + uint8_t* op_0_workspace = reinterpret_cast(workspace); + uint8_t* op_1_workspace = op_0_workspace + op_0_workspace_size; + uint8_t* op_2_workspace = op_1_workspace + op_1_workspace_size; + uint8_t* op_3_workspace = op_2_workspace + op_2_workspace_size; + return Params{ + Op0::to_underlying_arguments(problem_shape, args.op_0, op_0_workspace), + Op1::to_underlying_arguments(problem_shape, args.op_1, op_1_workspace), + Op2::to_underlying_arguments(problem_shape, args.op_2, op_2_workspace), + Op3::to_underlying_arguments(problem_shape, args.op_3, op_3_workspace) + }; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + return Op0::can_implement(problem_shape, args.op_0) && + Op1::can_implement(problem_shape, args.op_1) && + Op2::can_implement(problem_shape, args.op_2) && + Op3::can_implement(problem_shape, args.op_3); + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + size_t workspace_size = 0; + workspace_size += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += Op1::get_workspace_size(problem_shape, args.op_1); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += Op2::get_workspace_size(problem_shape, args.op_2); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + workspace_size += Op3::get_workspace_size(problem_shape, args.op_3); + workspace_size = round_nearest(workspace_size, MinWorkspaceAlignment); + + return workspace_size; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + Status status = Status::kSuccess; + uint8_t* workspace_ptr = reinterpret_cast(workspace); + size_t workspace_offset = 0; + + status = Op0::initialize_workspace(problem_shape, args.op_0, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += Op0::get_workspace_size(problem_shape, args.op_0); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = Op1::initialize_workspace(problem_shape, args.op_1, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += Op1::get_workspace_size(problem_shape, args.op_1); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = Op2::initialize_workspace(problem_shape, args.op_2, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += Op2::get_workspace_size(problem_shape, args.op_2); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + status = Op3::initialize_workspace(problem_shape, args.op_3, workspace_ptr + workspace_offset, stream, cuda_adapter); + workspace_offset += Op3::get_workspace_size(problem_shape, args.op_3); + workspace_offset = round_nearest(workspace_offset, MinWorkspaceAlignment); + if (status != Status::kSuccess) { + return status; + } + + return status; + } + + CUTLASS_HOST_DEVICE + Sm90VisitorImplBase() {} + + CUTLASS_HOST_DEVICE 
+ Sm90VisitorImplBase(Params const& params, SharedStorage const& shared_storage) + : ops({ + Op0(params.op_0, get<0>(shared_storage)), + Op1(params.op_1, get<1>(shared_storage)), + Op2(params.op_2, get<2>(shared_storage)), + Op3(params.op_3, get<3>(shared_storage)) + }) {} + + tuple ops; +}; + +} // namespace detail + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp b/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp new file mode 100644 index 0000000000..53c0dce8ba --- /dev/null +++ b/include/cutlass/epilogue/fusion/sm90_visitor_topk_softmax.hpp @@ -0,0 +1,759 @@ +/*************************************************************************************************** + * Copyright (c) 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * SPDX-License-Identifier: BSD-3-Clause + * + * Redistribution and use in source and binary forms, with or without + * modification, are permitted provided that the following conditions are met: + * + * 1. Redistributions of source code must retain the above copyright notice, this + * list of conditions and the following disclaimer. + * + * 2. Redistributions in binary form must reproduce the above copyright notice, + * this list of conditions and the following disclaimer in the documentation + * and/or other materials provided with the distribution. + * + * 3. Neither the name of the copyright holder nor the names of its + * contributors may be used to endorse or promote products derived from + * this software without specific prior written permission. + * + * THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" + * AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE + * IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE + * DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE + * FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL + * DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR + * SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER + * CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, + * OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE + * OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE. + * + **************************************************************************************************/ + +/*! \file + \brief Visitor tree Top-K + Softmax fusion operation for sm90 TMA warp-specialized epilogue +*/ + +#pragma once + +#include "cutlass/cutlass.h" +#include "cutlass/workspace.h" + +#include "cute/tensor.hpp" +#include "sm90_visitor_tma_warpspecialized.hpp" + +///////////////////////////////////////////////////////////////////////////////////////////////// + +namespace cutlass::epilogue::fusion { + +///////////////////////////////////////////////////////////////////////////////////////////////// + +// Top-K + Softmax reduction across columns +// Performs a reduction of top-K values across N, and finally performs a softmax on them, +// and sets values not in the top-K to 0. +// +// Assumptions: +// 1. CTA_N >= N (single tile across N, the mode which is reduced) +// 2. 
EPI_N >= N (single epilogue tile across N, because we can reduce and revisit one +// epilogue tile at a time.) +// 3. Top-K value is either 2 or 4. +// + +namespace detail { + +// Implementations for add to sorted list and merging sorted lists, +// with fast paths for lists of size 2 and 4 (Top-2 and Top-4). +// Generic implementations may result in greater register use and branching, +// and should be avoided. +// Fast paths for Top-2 and Top-4 are written in inline PTX directly. + +CUTLASS_DEVICE +Array top_2_reduce_scalar(Array a, float scalar) { + Array out; + asm volatile( + "{\n" + " .reg .f32 mx;\n" + " .reg .pred p;\n" + " max.f32 mx, %3, %4;\n" + " setp.gtu.f32 p, %2, %4;\n" + " selp.f32 %1, mx, %2, p;\n" + " selp.f32 %0, %2, %4, p;\n" + "}\n" : "=f"(out[0]), "=f"(out[1]) : "f"(a[0]), "f"(a[1]), "f"(scalar)); + return out; +} + +CUTLASS_DEVICE +Array top_2_reduce(Array a, Array b) { + Array out; + asm volatile( + "{\n" + " .reg .v2 .f32 mx;\n" + " .reg .pred p;\n" + " max.f32 mx.x, %3, %4;\n" // max(a1, b0) + " max.f32 mx.y, %2, %5;\n" // max(a0, b1) + " setp.gtu.f32 p, %2, %4;\n" // a0 > b0 + " selp.f32 %1, mx.x, mx.y, p;\n" // a0 > b0 ? max(a1, b0) : max(a0, b1) + " selp.f32 %0, %2, %4, p;\n" // a0 > b0 ? a0 : b0 + "}\n" : "=f"(out[0]), "=f"(out[1]) : + "f"(a[0]), "f"(a[1]), "f"(b[0]), "f"(b[1])); + return out; +} + +CUTLASS_DEVICE +Array top_4_reduce_scalar(Array a, float scalar) { + Array out; + asm volatile( + "{\n" + " .reg .f32 mx;\n" // max(a3, b) + " .reg .pred p0;\n" // a0 > b + " .reg .pred p1;\n" // a1 > b + " .reg .pred p2;\n" // a2 > b + " max.f32 mx, %7, %8;\n" // max(a3, b) + " setp.gtu.f32 p0, %4, %8;\n" // a0 > b + " setp.gtu.f32 p1, %5, %8;\n" // a1 > b + " setp.gtu.f32 p2, %6, %8;\n" // a2 > b + " selp.f32 %3, mx, %6, p2;\n" // a2 > b ? max(a3, b) : a2 + " selp.f32 %2, %6, %8, p2;\n" // a1 = a2 > b ? a2 : b + " selp.f32 %2, %2, %5, p1;\n" // a1 > b ? max(a2, b) : a1 == a1 > b ? a1 : old_a1 + " selp.f32 %1, %5, %8, p1;\n" // a0 = a1 > b ? a1 : b + " selp.f32 %1, %1, %4, p0;\n" // a0 > b ? max(a1, b) : a0 == a0 > b ? a0 : old_a0 + " selp.f32 %0, %4, %8, p0;\n" // a0 = a0 > b ? 
a0 : b + "}\n" : + "=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]) : + "f"(a[0]), "f"(a[1]), "f"(a[2]), "f"(a[3]), "f"(scalar)); + return out; +} + +CUTLASS_DEVICE +Array top_4_reduce(Array a, Array b) { + Array out; + asm volatile( + "{\n" + " .reg .f32 mxa0b1;\n" // max(a0, b1) + " .reg .f32 mxa1b0;\n" // max(a1, b0) + + " .reg .f32 mxa2b0;\n" // max(a2, b0) + " .reg .f32 mxa1b1;\n" // max(a1, b1) + " .reg .f32 mxa0b2;\n" // max(a1, b1) + + " .reg .f32 mxa1b2;\n" // max(a1, b2) + " .reg .f32 mxa2b1;\n" // max(a2, b1) + " max.f32 mxa1b2, %5, %10;\n" + " max.f32 mxa2b1, %6, %9;\n" + + " .reg .f32 mxa3b0;\n" // max(a1, b2) + " .reg .f32 mxa0b3;\n" // max(a2, b1) + " max.f32 mxa3b0, %7, %8;\n" + " max.f32 mxa0b3, %4, %11;\n" + + " .reg .pred pa0b0;\n" // a0 > b0 + " .reg .pred pa1b0;\n" // a1 > b0 + " .reg .pred pa2b0;\n" // a2 > b0 + " .reg .pred pa0b1;\n" // a0 > b1 + " .reg .pred pa1b1;\n" // a1 > b1 + " .reg .pred pa0b2;\n" // a0 > b2 + " .reg .pred pb2a0;\n" // b1 > a0 + " .reg .pred pb1a0;\n" // b1 > a0 + + " setp.gtu.f32 pa0b0, %4, %8;\n" // a0 > b0 + " setp.gtu.f32 pa1b0, %5, %8;\n" // a1 > b0 + " setp.gtu.f32 pa2b0, %6, %8;\n" // a2 > b0 + " setp.gtu.f32 pa0b1, %4, %9;\n" // a0 > b1 + " setp.gtu.f32 pa1b1, %5, %9;\n" // a1 > b1 + " setp.gtu.f32 pa0b2, %4, %10;\n" // a0 > b2 + + " not.pred pb2a0, pa0b2;\n" + " not.pred pb1a0, pa0b1;\n" + + " selp.f32 mxa1b0, %5, %8, pa1b0;\n" // max(a1, b0) + " selp.f32 mxa0b1, %4, %9, pa0b1;\n" // max(a0, b1) + + " selp.f32 mxa1b1, %5, %9, pa1b1;\n" // max(a1, b1) + " selp.f32 mxa2b0, %6, %8, pa2b0;\n" // max(a2, b0) + " selp.f32 mxa0b2, %4, %10, pa0b2;\n" // max(a0, b2) + + // a0 + " selp.f32 %0, %4, %8, pa0b0;\n" // a0 = a0 > b0 ? a0 : b0 + + // a1 + " selp.f32 %1, mxa1b0, mxa0b1, pa0b0;\n" // a1 = a0 > b0 ? max(a1, b0) : max(a0, b1) + + // a2 + " mov.f32 %2, mxa1b1;\n" // a2 = max(a1, b1) ** most likely case + " selp.f32 %2, mxa2b0, %2, pa1b0;\n" // a0 > a1 > b0 + " selp.f32 %2, mxa0b2, %2, pb1a0;\n" // b0 > b1 > a0 + + // a3 + " mov.f32 %3, mxa1b2;\n" // a3 = max(a1, b2) ** one of the most likely cases + " selp.f32 %3, mxa2b1, %3, pa1b1;\n" // a3 = a1 > b1 ? max(a2, b1) ** second most likely case + " selp.f32 %3, mxa3b0, %3, pa2b0;\n" // a0 > a1 > a2 > b0 + " selp.f32 %3, mxa0b3, %3, pb2a0;\n" // b0 > b1 > b2 > a0 + "}\n" : + "=f"(out[0]), "=f"(out[1]), "=f"(out[2]), "=f"(out[3]) : + "f"(a[0]), "f"(a[1]), "f"(a[2]), "f"(a[3]), + "f"(b[0]), "f"(b[1]), "f"(b[2]), "f"(b[3])); + return out; +} + +// Assumption: array elements are sorted in descending order +// (a[0] is the largest element in a[].) +template +CUTLASS_DEVICE +void add_element_to_desc_sorted_array(cutlass::Array& a, Element b) { + if constexpr (N == 2 && is_same_v) { + a = top_2_reduce_scalar(a, b); + } + else if constexpr (N == 4 && is_same_v) { + a = top_4_reduce_scalar(a, b); + } + else { + // slower generic path with branching, slower, and can cause register spill + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < N; ++k) { + if (a[k] <= b) { + // Shift down + CUTLASS_PRAGMA_UNROLL + for (int l = N - 1; l > k; --l) { + a[l] = a[l-1]; + } + a[k] = b; + } + } + } +} + +// Assumption: array elements are sorted in descending order +// (a[0] and b[0] are the largest elements in a[] and b[].) 
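// For example, with Top-2 and a = {9, 5}, b = {7, 3}, the merge leaves a = {9, 7}: the two largest
// values of the union of the two sorted lists. A minimal host-side reference of this merge
// (illustrative only; std::array and the helper name are assumptions, not part of this header):
//
//   template <class T, std::size_t N>
//   std::array<T, N> merge_desc_reference(std::array<T, N> a, std::array<T, N> b) {
//     std::array<T, N> out{};
//     std::size_t i = 0, j = 0;
//     for (std::size_t k = 0; k < N; ++k) {
//       out[k] = (a[i] >= b[j]) ? a[i++] : b[j++];   // both lists are descending: take the larger head
//     }
//     return out;
//   }
//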
+template +CUTLASS_DEVICE +void merge_desc_sorted_arrays(cutlass::Array& a, const cutlass::Array& b) { + if constexpr (N == 2 && is_same_v) { + a = top_2_reduce(a, b); + } + else if constexpr (N == 4 && is_same_v) { + a = top_4_reduce(a, b); + } + else { + // slower generic path with branching, slower, and can cause register spill + int j = 0; + CUTLASS_PRAGMA_UNROLL + for (int k = 0; k < N; ++k) { + if (a[k] <= b[j]) { + // Shift down + CUTLASS_PRAGMA_UNROLL + for (int l = N - 1; l > k; --l) { + a[l] = a[l-1]; + } + a[k] = b[j]; + ++j; + } + } + } +} + +// Assumption: array elements are sorted in descending order +// (a[0] is the largest element in a[].) +template +CUTLASS_DEVICE +Element topk_logsumexp(cutlass::Array a) { + // Do one less `exp`, because we know what its result will be. + // Assume x is a set of `x_i`s, and `x_m` is the maximum of that set. + // logsumexp(x) = log(sum(x_i)) = m + log(sum(x_i - m)) = m + log(1 + sum_{i != m}(x_i - x_m)) + // Compute m + log(1 + sum_{i != m}(x_i - x_m)) + Element sum = Element(1.0); + CUTLASS_PRAGMA_UNROLL + for (int i = 1; i < N; ++i) { + sum += fast_exp(a[i] - a[0]); + } + return a[0] + fast_log(sum); +} + +CUTLASS_DEVICE +float fast_masked_softmax(float value, float minimum, float logsumexp) { + float new_value; + asm volatile( + "{\n" + " .reg .pred p0;\n" + // value >= minimum + " setp.geu.f32 p0, %1, %2;\n" + + " .reg .f32 x_lse;\n" + " .reg .f32 %%f<11>;\n" + " .reg .b32 %%r<3>;\n" + + // x_lse = value - minimum + " sub.rn.f32 x_lse, %1, %3;\n" + + // exp(x_lse) + // The following is derived from a ptx dump of expf. + // exp requires a base conversion from exp2. + " fma.rn.f32 %%f1, x_lse, 0f3BBB989D, 0f3F000000;\n" + " cvt.sat.f32.f32 %%f2, %%f1;\n" + " fma.rm.f32 %%f3, %%f2, 0f437C0000, 0f4B400001;\n" + " add.f32 %%f4, %%f3, 0fCB40007F;\n" + " neg.f32 %%f5, %%f4;\n" + " fma.rn.f32 %%f6, x_lse, 0f3FB8AA3B, %%f5;\n" + " fma.rn.f32 %%f7, x_lse, 0f32A57060, %%f6;\n" + " mov.b32 %%r1, %%f3;\n" + " shl.b32 %%r2, %%r1, 23;\n" + " mov.b32 %%f8, %%r2;\n" + " ex2.approx.ftz.f32 %%f9, %%f7;\n" + " mul.f32 %%f10, %%f9, %%f8;\n" + + // Mask or softmax + " selp.f32 %0, %%f10, 0f00000000, p0;\n" + "}\n" : "=f"(new_value) : "f"(value), "f"(minimum), "f"(logsumexp)); + return new_value; +} + +template +CUTLASS_DEVICE +Element masked_softmax(Element value, Element minimum, Element logsumexp) { + if constexpr (is_same_v) { + // Inline PTX implementation + // Significantly reduces register requirements + return fast_masked_softmax(value, minimum, logsumexp); + } + else { + return value < minimum ? Element(0.0) : fast_exp(value - logsumexp); + } +} + +} // namespace detail + +template < + int TopK, + int FragmentSize, + class CtaTileShapeMNK, + class EpilogueTile, + class ElementOutput, + class ElementCompute, + FloatRoundStyle RoundStyle, + int Alignment = 128 / sizeof_bits_v, + bool UseButterflyReduce = true +> +struct Sm90TopKSoftmaxColReduction { +private: + static_assert(is_same_v, "Fused Top-K + Softmax reduction requires FP32 accumulation."); + static_assert(TopK == 2 || TopK == 4, "Fused Top-K + Softmax reduction only supports K=2 and K=4."); + static_assert(Alignment * sizeof_bits_v % 128 == 0, "sub-16B alignment not supported yet"); + + // Reduction tensors + // We have two tensors for this EVT node: a reduction tensor and a tensor holding + // final reduction values (tCrSoftmax). The reason for this is that Top-K and Softmax + // require different reductions, but those luckily overlap. 
Top-K obviously needs at least + // two values (K >= 2), and softmax needs one value: logsumexp. Logsumexp is simply the log + // of sum of exponents over the set, and is equivalent to m + sum(exp(x_i - m)), where m is the + // maximum of all x_i elements. Since safe softmax for any element x_i is computed as + // softmax(x_i) = exp(x_i - m) / sum_j(exp(x_j - max)) + // we can track logsumexp instead of tracking two variables (sum of exps and the max). + // In addition, subtracting logsumexp from any element and taking its exp is equivalent to + // computing its softmax. + // + // The overlap between softmax and top-K is that we don't need to reduce logsumexp along the + // way at all, because any element not in the top-K is going to be masked out and set to 0. + // Therefore, we only reduce the top-K elements, and when done, compute their logsumexp and + // keep it, and the smallest element in the top-K for masking out non-top-K elements. + // + // This means that our final reduction result will always be 2 elements, regardless of the value + // of K: minimum of top-K, and logsumexp. + // + // For each reduction tensor, we define a new struct for readability. + + struct ReductionResult { + ElementCompute min_; + ElementCompute logsumexp_; + + CUTLASS_DEVICE + ReductionResult() { } + + CUTLASS_DEVICE + ReductionResult(ElementCompute min, ElementCompute logsumexp): + logsumexp_(logsumexp), min_(min) { } + + // Warp shuffle broadcast + CUTLASS_DEVICE + void shuffle_up_sync(uint32_t delta, int lane_id) { + static_assert(sizeof(ReductionResult) == sizeof(uint64_t)); + uint64_t r = reinterpret_cast(*this); + r = __shfl_up_sync(0xFFFFFFFF, r, delta); + *this = (lane_id - static_cast(delta) >= 0) ? reinterpret_cast(r) : *this; + } + }; + + struct TopKResult { + Array top_k_; + + CUTLASS_DEVICE + TopKResult() { + top_k_.fill(-cutlass::platform::numeric_limits::infinity()); + } + + // This is where we do the "final" reduction, where we compute + // the logsumexp for softmax, keep the smallest value in top-K, + // and discard the rest. 
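    // As a concrete illustration (Top-4, values already sorted descending), if
    // top_k_ = {2.0, 1.0, 0.5, 0.25}, then the final reduction would produce (the numbers are an
    // example, not taken from a real run):
    //
    //   min_       = 0.25                                             // smallest retained value
    //   logsumexp_ = 2.0 + log(1 + e^{-1.0} + e^{-1.5} + e^{-1.75})   // ~= 2.57
    //
    // so that any retained element x can later be mapped to its softmax value exp(x - logsumexp_),
    // while elements below min_ are masked to zero.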
+ CUTLASS_DEVICE + ReductionResult reduce_final() const { + return ReductionResult(top_k_[TopK - 1], topk_logsumexp(top_k_)); + } + + // Butterfly reduction + CUTLASS_DEVICE + void shuffle_xor_sync(int laneMask) { + if constexpr (TopK == 2) { + static_assert(sizeof(TopKResult) == sizeof(uint64_t)); + uint64_t top_k = reinterpret_cast(*this); + top_k = __shfl_xor_sync(0xFFFFFFFF, top_k, laneMask); + auto synced_v = reinterpret_cast(top_k); + detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); + } + else if constexpr (TopK == 4) { + static_assert(sizeof(TopKResult) == 2 * sizeof(uint64_t)); + uint64_t* top_k_ptr = reinterpret_cast(this); + uint64_t top_k_arr[2]; + top_k_arr[0] = top_k_ptr[0]; + top_k_arr[1] = top_k_ptr[1]; + top_k_arr[0] = __shfl_xor_sync(0xFFFFFFFF, top_k_arr[0], laneMask); + top_k_arr[1] = __shfl_xor_sync(0xFFFFFFFF, top_k_arr[1], laneMask); + auto synced_v = reinterpret_cast(top_k_arr); + detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); + } + else { + TopKResult synced_v; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < TopK; ++i) { + synced_v.top_k_[i] = __shfl_xor_sync(0xFFFFFFFF, top_k_[i], laneMask); + } + detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); + } + } + + // Warp shuffle reduction + CUTLASS_DEVICE + void shuffle_down_sync(uint32_t delta) { + if constexpr (TopK == 2) { + static_assert(sizeof(TopKResult) == sizeof(uint64_t)); + uint64_t top_k = reinterpret_cast(*this); + top_k = __shfl_down_sync(0xFFFFFFFF, top_k, delta); + auto synced_v = reinterpret_cast(top_k); + detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); + } + else if constexpr (TopK == 4) { + static_assert(sizeof(TopKResult) == 2 * sizeof(uint64_t)); + uint64_t* top_k_ptr = reinterpret_cast(this); + uint64_t top_k_arr[2]; + top_k_arr[0] = top_k_ptr[0]; + top_k_arr[1] = top_k_ptr[1]; + top_k_arr[0] = __shfl_down_sync(0xFFFFFFFF, top_k_arr[0], delta); + top_k_arr[1] = __shfl_down_sync(0xFFFFFFFF, top_k_arr[1], delta); + auto synced_v = reinterpret_cast(top_k_arr); + detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); + } + else { + TopKResult synced_v; + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < TopK; ++i) { + synced_v.top_k_[i] = __shfl_down_sync(0xFFFFFFFF, top_k_[i], delta); + } + detail::merge_desc_sorted_arrays(top_k_, synced_v.top_k_); + } + } + }; + +public: + struct SharedStorage { }; + + struct Arguments { }; + + struct Params { }; + + template + static constexpr Params + to_underlying_arguments(ProblemShape const& problem_shape, Arguments const& args, void* workspace) { + return {}; + } + + template + static bool + can_implement(ProblemShape const& problem_shape, Arguments const& args) { + auto [M, N, K, L] = problem_shape; + auto [tile_M, tile_N, tile_K] = CtaTileShapeMNK{}; + // Cross CTA reduction is not possible because there is no guarantee that all CTAs run + // concurrently. + // Cross epilogue tile reduction is possible, but re-visiting and applying reduction + // to accumulators is only possible for the current epilogue tile. 
+ auto [epi_M, epi_N] = EpilogueTile{}; + return N <= tile_N && N <= epi_N && N >= TopK; + } + + template + static size_t + get_workspace_size(ProblemShape const& problem_shape, Arguments const& args) { + return 0; + } + + template + static cutlass::Status + initialize_workspace(ProblemShape const& problem_shape, Arguments const& args, void* workspace, cudaStream_t stream, + CudaHostAdapter* cuda_adapter = nullptr) { + return Status::kSuccess; + } + + CUTLASS_DEVICE bool + is_producer_load_needed() const { + return false; + } + + CUTLASS_DEVICE bool + is_C_load_needed() const { + return false; + } + + CUTLASS_HOST_DEVICE + Sm90TopKSoftmaxColReduction() { } + + CUTLASS_HOST_DEVICE + Sm90TopKSoftmaxColReduction(Params const& params, SharedStorage const& shared_storage) + : params(params) { } + + Params params; + + template + CUTLASS_DEVICE auto + get_producer_load_callbacks(ProducerLoadArgs const& args) { + return EmptyProducerLoadCallbacks{}; + } + + template + struct ConsumerStoreCallbacks : EmptyConsumerStoreCallbacks { + CUTLASS_DEVICE + ConsumerStoreCallbacks(ArgsTuple&& args_tuple, Params const& params) + : args_tuple(cute::forward(args_tuple)), + params(params) {} + + ArgsTuple args_tuple; + Params const& params; + + template + CUTLASS_DEVICE auto + visit(Array const& frg_acc, int epi_v, int epi_m, int epi_n, + Array const& frg_input) { + + auto& [tCrTopK, tCrSoftmax, tCcCol, cCol, + lane_layout_MN, lane_mn, + residue_cCol, residue_tCcCol] = args_tuple; + Tensor tCcCol_mn = tCcCol(_,_,_,epi_m,epi_n); + + using ConvertInput = NumericArrayConverter; + ConvertInput convert_input{}; + + Array frg_I = convert_input(frg_input); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + auto thread_crd = tCcCol_mn(epi_v * FragmentSize + i); + if (elem_less(thread_crd, residue_tCcCol)) { + TopKResult& tCrCol_vmn = tCrTopK(epi_v * FragmentSize + i); + detail::add_element_to_desc_sorted_array(tCrCol_vmn.top_k_, frg_I[i]); + } + } + + return frg_input; + } + + template + CUTLASS_DEVICE void + reduce(STensor&& smem_buffer, SyncFn const& sync_fn, int epi_m, int epi_n, bool is_last_iteration, VTensor visit_results) { + + auto& [tCrTopK, tCrSoftmax, tCcCol, cCol, + lane_layout_MN, lane_mn, + residue_cCol, residue_tCcCol] = args_tuple; + + // fully OOB CTA in partially OOB cluster + if (not elem_less(cCol(_0{},_0{}), residue_cCol)) { + return; + } + Tensor tCcCol_mn = tCcCol(_,_,_,epi_m,epi_n); + + // `tCrTopK` and `tCrSoftmax` have 0-strides along modes that correspond to N, + // in order to reduce along modes in the `R2S` sublayout that correspond to N. + // This means we should modify and warp-reduce them according to their co-domain instead of + // their domain. Therefore we keep a filtered view of both and use them as necessary. + auto tCrTopK_f = filter(tCrTopK); + auto tCrSoftmax_f = filter(tCrSoftmax); + + // The pattern here is: reduce Top-K first, then compute logsumexp, keep it and the + // last element of Top-K, use the latter to mask the visited results, and the former + // to apply softmax. + // + // This gives us two options: reduce the Top-K with warp shuffles, have the reduced + // lanes compute logsumexp and pair it with the last Top-K element, and broadcast + // the result back using warp shuffles. + // + // Alternatively, we can do a butterfly reduction over Top-K, and have all lanes + // compute their own logsumexp and skip the broadcast. + if constexpr (UseButterflyReduce) { + // + // 1. 
Butterfly reduction + // + CUTLASS_PRAGMA_UNROLL + for (int j = 1; j < size<1>(lane_layout_MN); j *= 2) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrTopK_f); ++i) { + tCrTopK_f(i).shuffle_xor_sync(j); + } + } + + // + // 2. Strip down reduced value and compute sum of exps + // + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrSoftmax_f); ++i) { + tCrSoftmax_f(i) = tCrTopK_f(i).reduce_final(); + } + } + else { + // + // 1. Warp shuffle reduction + // + CUTLASS_PRAGMA_UNROLL + for (int reduction_cols = size<1>(lane_layout_MN) / 2; reduction_cols > 0; reduction_cols /= 2) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrTopK_f); ++i) { + tCrTopK_f(i).shuffle_down_sync(lane_layout_MN(_0{},reduction_cols)); + } + } + + // + // 2. Strip down reduced value and compute sum of exps + // + bool is_reduced_lane = get<1>(lane_mn) == 0; + if (is_reduced_lane) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrSoftmax_f); ++i) { + tCrSoftmax_f(i) = tCrTopK_f(i).reduce_final(); + } + } + + // + // 3. Broadcast reduced values to all participants + // + CUTLASS_PRAGMA_UNROLL + for (int broadcast_cols = 1; broadcast_cols <= size<1>(lane_layout_MN) / 2; broadcast_cols *= 2) { + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < size(tCrSoftmax_f); ++i) { + tCrSoftmax_f(i).shuffle_up_sync(lane_layout_MN(_0{},broadcast_cols), get<1>(lane_mn)); + } + } + } + + // + // 4. Re-visit and apply top-K and softmax + // + CUTLASS_PRAGMA_UNROLL + for (int epi_v = 0; epi_v < size(visit_results); ++epi_v) { + auto& visit_frag = visit_results(epi_v); + CUTLASS_PRAGMA_UNROLL + for (int i = 0; i < FragmentSize; ++i) { + visit_frag[i] = detail::masked_softmax( + visit_frag[i], + tCrSoftmax(epi_v * FragmentSize + i).min_, + tCrSoftmax(epi_v * FragmentSize + i).logsumexp_ + ); + } + } + + } + + CUTLASS_DEVICE void + end_loop(int epi_m, int epi_n) { + auto& [tCrTopK, tCrSoftmax, tCcCol, cCol, + lane_layout_MN, lane_mn, + residue_cCol, residue_tCcCol] = args_tuple; + + // Reset reduced top-K values for next tile + // This must be done because we only assume a single epilogue tile across N, + // but not M. + fill(tCrTopK, TopKResult()); + } + + CUTLASS_DEVICE void + end() { } + + }; + + template < + bool ReferenceSrc, // do register tensors reference the src or dst layout of the tiled copy + class... 
Args + > + CUTLASS_DEVICE auto + get_consumer_store_callbacks(ConsumerStoreArgs const& args) { + Layout ref_layout_MN = [&] () { + if constexpr (ReferenceSrc) { return get<0>(args.tiled_copy.get_layoutS_MN()); } + else { return get<0>(args.tiled_copy.get_layoutD_MN()); } + }(); // tile_mn -> tv_idx + + // Get the MN layout + coord of lanes to determine shuffle reduction iterations + using _W = Int; + Layout tv2lane = Layout,_W,_1>,Stride<_1,_0,_0>>{}; // tv_idx -> lane_idx + Layout ref2lane = composition(tv2lane, ref_layout_MN); // tile_mn -> lane_idx + Layout lane_layout_MN = make_layout(filter(get<0>(ref2lane)), filter(get<1>(ref2lane))); // lane_mn -> lane_idx + Layout inv_lane_layout_MN = right_inverse(lane_layout_MN); // lane_idx -> lane_mn + int lane_idx = canonical_lane_idx(); + auto lane_mn = idx2crd(inv_lane_layout_MN(lane_idx), shape(lane_layout_MN)); + + // Get the MN layout + coord of warps to determine smem reduction iterations + Layout tv2warp = Layout,_W,_1>,Stride<_0,_1,_0>>{}; // tv_idx -> warp_idx + Layout ref2warp = composition(tv2warp, ref_layout_MN); // tile_mn -> warp_idx + Layout warp_layout_MN = make_layout(filter(get<0>(ref2warp)), filter(get<1>(ref2warp))); // warp_mn -> warp_idx + + // Make sure there's only one warp across N so we can use warp shuffle intrinsics for reduction. + static_assert(decltype(size<1>(warp_layout_MN))::value <= 1); + + // Reduction layout + // We're assuming all elements in a row (over which we're performing the reduction) are + // visited in the same corresponding epilogue tile, and this is what allows us to apply the + // top-K + softmax operation within `reduce()`, by re-visiting the accumulated results. + // + // This presents a challenge, because the layout of the accumulated results is typically in + // in the register to shared memory shape, or: (R2S,R2S_M,R2S_N). + // This means that we still need to reduce this tensor along N. + // + // The solution is simple: we need to flatten the layout, identify modes that correspond to + // N and set their strides to 0, in order to map fragment indices corresponding to the same + // row back to the same element in the tensor. + // + // This requires some extra layout manipulation, which is as follows. + + // Create new accumulator layout with column broadcast + auto [M, N, K] = args.tile_shape_mnk; + auto thr_mma = args.tiled_mma.get_thread_slice(args.thread_idx); + auto gColReduce = make_tensor( + make_layout(make_shape(M, N), make_stride(_1{}, 0_c))); // (M,N) + auto tCrColReduce = make_tensor_like( // (FrgV, MMA_M, MMA_N) + thr_mma.partition_C(gColReduce).layout()); + + // Tile the new accumulator tensor according to R2S + ThrCopy thread_r2s = args.tiled_copy.get_slice(args.thread_idx); + Tensor tRS_rSoftmax = thread_r2s.retile_S(tCrColReduce); // ((R2S,R2S_V),MMA_M,MMA_N) + auto tCrC_layout = args.tCrC.layout(); // (R2S,R2S_M,R2S_N) + + // Compose the new accumulator R2S layout with the expected tCrC layout to get final + // reduction tensor layout. 
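    // The column-broadcast trick above relies on a stride-0 mode: every column index of a row maps
    // to the same storage element, so all fragment values belonging to one row land in the same
    // reduction slot. A small standalone illustration (the shapes here are made up for the example):
    //
    //   auto bcast = make_layout(make_shape(4, 8), make_stride(_1{}, _0{}));   // (M,N):(1,0)
    //   // bcast(2, 0) == bcast(2, 5) == 2 -- the N mode is collapsed
    //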
+ auto tCrSoftmax_layout = take<0, 3>(tRS_rSoftmax.layout()).compose(tCrC_layout); // (R2S,R2S_V) o (R2S,R2S_M,R2S_N) + + Tensor tCrTopK = make_tensor(tCrSoftmax_layout); // (R2S,R2S_M,R2S_N) + Tensor tCrSoftmax = make_tensor(tCrSoftmax_layout); // (R2S,R2S_M,R2S_N) + fill(tCrTopK, TopKResult()); + + auto args_tuple = make_tuple( + cute::move(tCrTopK), cute::move(tCrSoftmax), args.tCcD, args.cD, + lane_layout_MN, lane_mn, + args.residue_cD, args.residue_tCcD); + return ConsumerStoreCallbacks(std::move(args_tuple), params); + } +}; + +///////////////////////////////////////////////////////////////////////////////////////////////// + +} // namespace cutlass::epilogue::fusion + +///////////////////////////////////////////////////////////////////////////////////////////////// diff --git a/include/cutlass/epilogue/thread/activation.h b/include/cutlass/epilogue/thread/activation.h index 9763f5fc1e..186e996602 100644 --- a/include/cutlass/epilogue/thread/activation.h +++ b/include/cutlass/epilogue/thread/activation.h @@ -1,5 +1,5 @@ /*************************************************************************************************** - * Copyright (c) 2017 - 2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. + * Copyright (c) 2017 - 2024 NVIDIA CORPORATION & AFFILIATES. All rights reserved. * SPDX-License-Identifier: BSD-3-Clause * * Redistribution and use in source and binary forms, with or without @@ -37,6 +37,7 @@ #include "cutlass/cutlass.h" #include "cutlass/numeric_types.h" +#include "cutlass/numeric_conversion.h" #include "cutlass/constants.h" #include "cutlass/complex.h" #include "cutlass/array.h" @@ -51,23 +52,88 @@ namespace thread { ///////////////////////////////////////////////////////////////////////////////////////////////// +// Identity operator template struct Identity { + static const bool kIsHeavy = false; + CUTLASS_HOST_DEVICE T operator()(T value) const { return value; } }; -///////////////////////////////////////////////////////////////////////////////////////////////// +template +struct Identity > { + CUTLASS_HOST_DEVICE + Array operator()(Array value) const { + return value; + } +}; + +/// Scale operator +template +struct Scale { + struct Arguments { + using scale_type = T; + T scale = T(1); + }; + + CUTLASS_HOST_DEVICE + T operator()(T value, T scale) const { + multiplies mul; + return mul(scale, value); + } + + CUTLASS_HOST_DEVICE + T operator()(T value, Arguments args = Arguments()) const { + return this->operator()(value, args.scale); + } +}; + +template +struct Scale> { + using Arguments = typename Scale::Arguments; + + CUTLASS_HOST_DEVICE + Array operator()(Array values, T scale) const { + multiplies> mul; + return mul(scale, values); + } + + CUTLASS_HOST_DEVICE + Array operator()(Array values, Arguments args = Arguments()) const { + return this->operator()(values, args.scale); + } +}; + +/// Specialization to compose other activations with a defined unary operator +/// e.g. Scale> +template